改进模型注册,动态注册

This commit is contained in:
czzhangheng 2025-12-14 17:48:37 +08:00
parent 19fd7622a3
commit 5827554c73
34 changed files with 390 additions and 104 deletions

View File

@ -0,0 +1,7 @@
[
{
"name": "AGCRN",
"module": "model.AGCRN.AGCRN",
"entry": "AGCRN"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "ARIMA",
"module": "model.ARIMA.ARIMA",
"entry": "ARIMA"
}
]

View File

@ -0,0 +1,17 @@
[
{
"name": "ASTRA",
"module": "model.ASTRA.astra",
"entry": "ASTRA"
},
{
"name": "ASTRA_v2",
"module": "model.ASTRA.astrav2",
"entry": "ASTRA"
},
{
"name": "ASTRA_v3",
"module": "model.ASTRA.astrav3",
"entry": "ASTRA"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "DCRNN",
"module": "model.DCRNN.dcrnn_model",
"entry": "DCRNNModel"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "DDGCRN",
"module": "model.DDGCRN.DDGCRN",
"entry": "DDGCRN"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "DSANET",
"module": "model.DSANET.DSANET",
"entry": "DSANet"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "EXP",
"module": "model.EXP.EXP32",
"entry": "EXP"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "GWN",
"module": "model.GWN.GraphWaveNet",
"entry": "gwnet"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "HI",
"module": "model.HI.HI",
"entry": "HI"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "Informer",
"module": "model.Informer.model",
"entry": "Informer"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "MTGNN",
"module": "model.MTGNN.MTGNN",
"entry": "gtnet"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "MegaCRN",
"module": "model.MegaCRN.MegaCRNModel",
"entry": "MegaCRNModel"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "NLT",
"module": "model.NLT.HierAttnLstm",
"entry": "HierAttnLstm"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "PDG2SEQ",
"module": "model.PDG2SEQ.PDG2Seqb",
"entry": "PDG2Seq"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "PatchTST",
"module": "model.PatchTST.PatchTST",
"entry": "Model"
}
]

109
model/README.md Normal file
View File

@ -0,0 +1,109 @@
# 模型注册说明
## 概述
本项目使用基于配置文件的模型注册机制,每个模型目录下的 `model_config.json` 文件用于注册该目录下的模型。
## model_config.json 格式
### 基本格式
每个 `model_config.json` 文件是一个 JSON 数组,包含一个或多个模型配置对象:
```json
[
{
"name": "模型名称",
"module": "模型模块路径",
"entry": "模型入口点"
}
]
```
### 字段说明
- **name**: 模型的唯一标识符,用于在配置文件中选择模型
- **module**: 模型所在的模块路径,使用 Python 导入格式
- **entry**: 模型的入口点,可以是类名或函数名
### 示例
#### 1. 单个模型
```json
[
{
"name": "DDGCRN",
"module": "model.DDGCRN.DDGCRN",
"entry": "DDGCRN"
}
]
```
#### 2. 多个模型(同一目录下的不同版本)
```json
[
{
"name": "ASTRA",
"module": "model.ASTRA.astra",
"entry": "ASTRA"
},
{
"name": "ASTRA_v2",
"module": "model.ASTRA.astrav2",
"entry": "ASTRA"
},
{
"name": "ASTRA_v3",
"module": "model.ASTRA.astrav3",
"entry": "ASTRA"
}
]
```
#### 3. 函数模型
```json
[
{
"name": "STGNCDE",
"module": "model.STGNCDE.Make_model",
"entry": "make_model"
}
]
```
## 添加新模型
1. 在 `model` 目录下创建模型目录
2. 在该目录下实现模型代码
3. 创建 `model_config.json` 文件,配置模型信息
4. 在配置文件中使用模型名称选择模型
## 注意事项
1. 模型名称必须唯一,不允许重复
2. 模块路径必须是正确的 Python 导入路径
3. 入口点必须是模块中存在的类或函数
4. 配置文件必须是有效的 JSON 格式
5. 每个模型目录下只能有一个 `model_config.json` 文件
## 模型选择
在配置文件中,通过 `basic.model` 字段指定要使用的模型名称:
```json
{
"basic": {
"model": "ASTRA"
},
"model": {
// 模型特定配置
}
}
```
## 冲突检测
系统会自动检测模型名冲突,如有冲突会抛出 `AssertionError` 并显示冲突信息。

View File

@ -0,0 +1,7 @@
[
{
"name": "REPST",
"module": "model.REPST.repst",
"entry": "repst"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "STAEFormer",
"module": "model.STAEFormer.STAEFormer",
"entry": "STAEformer"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "STAWnet",
"module": "model.STAWnet.STAWnet",
"entry": "STAWnet"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "STFGNN",
"module": "model.STFGNN.STFGNN",
"entry": "STFGNN"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "STGCN",
"module": "model.STGCN.models",
"entry": "STGCNChebGraphConv"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "STGNCDE",
"module": "model.STGNCDE.Make_model",
"entry": "make_model"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "STGNRDE",
"module": "model.STGNRDE.Make_model",
"entry": "make_model"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "STGODE",
"module": "model.STGODE.STGODE",
"entry": "ODEGCN"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "STID",
"module": "model.STID.STID",
"entry": "STID"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "STIDGCN",
"module": "model.STIDGCN.STIDGCN",
"entry": "STIDGCN"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "STMLP",
"module": "model.STMLP.STMLP",
"entry": "STMLP"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "STSGCN",
"module": "model.STSGCN.STSGCN",
"entry": "STSGCN"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "ST_SSL",
"module": "model.ST_SSL.ST_SSL",
"entry": "STSSLModel"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "TCN",
"module": "model.TCN.TCN",
"entry": "TemporalConvNet"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "TWDGCN",
"module": "model.TWDGCN.TWDGCN",
"entry": "TWDGCN"
}
]

View File

@ -0,0 +1,7 @@
[
{
"name": "iTransformer",
"module": "model.iTransformer.iTransformer",
"entry": "iTransformer"
}
]

View File

@ -1,107 +1,57 @@
from model.DDGCRN.DDGCRN import DDGCRN
from model.HI import HI
from model.TWDGCN.TWDGCN import TWDGCN
from model.AGCRN.AGCRN import AGCRN
from model.NLT.HierAttnLstm import HierAttnLstm
from model.STGNCDE.Make_model import make_model
from model.DSANET.DSANET import DSANet
from model.STGCN.models import STGCNChebGraphConv
from model.DCRNN.dcrnn_model import DCRNNModel
from model.ARIMA.ARIMA import ARIMA
from model.TCN.TCN import TemporalConvNet
from model.GWN.GraphWaveNet import gwnet
from model.STFGNN.STFGNN import STFGNN
from model.STSGCN.STSGCN import STSGCN
from model.STGODE.STGODE import ODEGCN
from model.PDG2SEQ.PDG2Seqb import PDG2Seq
from model.STMLP.STMLP import STMLP
from model.STIDGCN.STIDGCN import STIDGCN
from model.STID.STID import STID
from model.STAEFormer.STAEFormer import STAEformer
from model.EXP.EXP32 import EXP as EXP
from model.MegaCRN.MegaCRNModel import MegaCRNModel
from model.ST_SSL.ST_SSL import STSSLModel
from model.STGNRDE.Make_model import make_model as make_nrde_model
from model.STAWnet.STAWnet import STAWnet
from model.REPST.repst import repst as REPST
from model.ASTRA.astra import ASTRA as ASTRA
from model.ASTRA.astrav2 import ASTRA as ASTRAv2
from model.ASTRA.astrav3 import ASTRA as ASTRAv3
from model.iTransformer.iTransformer import iTransformer
from model.Informer.model import Informer
from model.HI.HI import HI
from model.PatchTST.PatchTST import Model as PatchTST
from model.MTGNN.MTGNN import gtnet as MTGNN
import os
import json
import importlib
import sys
from pathlib import Path
class ModelRegistry:
def __init__(self):
self.models = {}
self.model_configs = {}
self.model_dir = Path(__file__).parent
self._load_model_configs()
def _load_model_configs(self):
"""加载所有model_config.json文件"""
# 直接遍历所有model_config.json文件
for config_path in self.model_dir.rglob("model_config.json"):
# 读取配置文件
with open(config_path, 'r') as f:
configs = json.load(f)
# 处理每个模型配置
for config in configs:
model_name = config["name"]
# 检查模型名冲突
assert model_name not in self.model_configs, f"模型名冲突: {model_name} 已存在,冲突文件: {config_path}"
self.model_configs[model_name] = config
def _load_model(self, model_name):
"""动态加载模型"""
if model_name not in self.model_configs:
raise ValueError(f"模型 {model_name} 未注册")
config = self.model_configs[model_name]
module = importlib.import_module(config["module"])
model_cls = getattr(module, config["entry"])
self.models[model_name] = model_cls
def get_model(self, model_name):
"""获取模型类或函数"""
if model_name not in self.models:
self._load_model(model_name)
return self.models[model_name]
# 初始化模型注册表
model_registry = ModelRegistry()
def model_selector(config):
model_name = config["basic"]["model"]
model_config = config["model"]
match model_name:
case "DDGCRN":
return DDGCRN(model_config)
case "TWDGCN":
return TWDGCN(model_config)
case "AGCRN":
return AGCRN(model_config)
case "NLT":
return HierAttnLstm(model_config)
case "STGNCDE":
return make_model(model_config)
case "DSANET":
return DSANet(model_config)
case "STGCN":
return STGCNChebGraphConv(model_config)
case "DCRNN":
return DCRNNModel(model_config)
case "ARIMA":
return ARIMA(model_config)
case "TCN":
return TemporalConvNet(model_config)
case "GWN":
return gwnet(model_config)
case "STFGNN":
return STFGNN(model_config)
case "STSGCN":
return STSGCN(model_config)
case "STGODE":
return ODEGCN(model_config)
case "PDG2SEQ":
return PDG2Seq(model_config)
case "STMLP":
return STMLP(model_config)
case "STIDGCN":
return STIDGCN(model_config)
case "STID":
return STID(model_config)
case "STAEFormer":
return STAEformer(model_config)
case "EXP":
return EXP(model_config)
case "MegaCRN":
return MegaCRNModel(model_config)
case "ST_SSL":
return STSSLModel(model_config)
case "STGNRDE":
return make_nrde_model(model_config)
case "STAWnet":
return STAWnet(model_config)
case "REPST":
return REPST(model_config)
case "ASTRA":
return ASTRA(model_config)
case "ASTRA_v2":
return ASTRAv2(model_config)
case "ASTRA_v3":
return ASTRAv3(model_config)
case "iTransformer":
return iTransformer(model_config)
case "Informer":
return Informer(model_config)
case "HI":
return HI(model_config)
case "PatchTST":
return PatchTST(model_config)
case "MTGNN":
return MTGNN(model_config)
model_cls = model_registry.get_model(model_name)
model = model_cls(model_config)
# print(f"\n=== 模型选择结果 ===")
print(f"选择的模型: {model_name}")
print(f"模型入口: {model_registry.model_configs[model_name]['module']}:{model_registry.model_configs[model_name]['entry']}")
return model

View File

@ -13,7 +13,7 @@ def read_config(config_path):
# 全局配置
device = "cuda:0" # 指定设备为cuda:0
seed = 2023 # 随机种子
epochs = 1
epochs = 100
# 拷贝项
config["basic"]["device"] = device
@ -63,14 +63,14 @@ def run(config):
if __name__ == "__main__":
# 指定模型
model_list = ["Informer"]
model_list = ["iTransformer"]
# 指定数据集
dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"]
# dataset_list = ["PEMS-BAY"]
# 我的调试开关,不做测试就填 str(False)
# os.environ["TRY"] = str(False)
os.environ["TRY"] = str(True)
os.environ["TRY"] = str(False)
for model in model_list:
for dataset in dataset_list: