改进模型注册,动态注册
This commit is contained in:
parent
19fd7622a3
commit
5827554c73
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "AGCRN",
|
||||
"module": "model.AGCRN.AGCRN",
|
||||
"entry": "AGCRN"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "ARIMA",
|
||||
"module": "model.ARIMA.ARIMA",
|
||||
"entry": "ARIMA"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
[
|
||||
{
|
||||
"name": "ASTRA",
|
||||
"module": "model.ASTRA.astra",
|
||||
"entry": "ASTRA"
|
||||
},
|
||||
{
|
||||
"name": "ASTRA_v2",
|
||||
"module": "model.ASTRA.astrav2",
|
||||
"entry": "ASTRA"
|
||||
},
|
||||
{
|
||||
"name": "ASTRA_v3",
|
||||
"module": "model.ASTRA.astrav3",
|
||||
"entry": "ASTRA"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "DCRNN",
|
||||
"module": "model.DCRNN.dcrnn_model",
|
||||
"entry": "DCRNNModel"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "DDGCRN",
|
||||
"module": "model.DDGCRN.DDGCRN",
|
||||
"entry": "DDGCRN"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "DSANET",
|
||||
"module": "model.DSANET.DSANET",
|
||||
"entry": "DSANet"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "EXP",
|
||||
"module": "model.EXP.EXP32",
|
||||
"entry": "EXP"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "GWN",
|
||||
"module": "model.GWN.GraphWaveNet",
|
||||
"entry": "gwnet"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "HI",
|
||||
"module": "model.HI.HI",
|
||||
"entry": "HI"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "Informer",
|
||||
"module": "model.Informer.model",
|
||||
"entry": "Informer"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "MTGNN",
|
||||
"module": "model.MTGNN.MTGNN",
|
||||
"entry": "gtnet"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "MegaCRN",
|
||||
"module": "model.MegaCRN.MegaCRNModel",
|
||||
"entry": "MegaCRNModel"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "NLT",
|
||||
"module": "model.NLT.HierAttnLstm",
|
||||
"entry": "HierAttnLstm"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "PDG2SEQ",
|
||||
"module": "model.PDG2SEQ.PDG2Seqb",
|
||||
"entry": "PDG2Seq"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "PatchTST",
|
||||
"module": "model.PatchTST.PatchTST",
|
||||
"entry": "Model"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
# 模型注册说明
|
||||
|
||||
## 概述
|
||||
|
||||
本项目使用基于配置文件的模型注册机制,每个模型目录下的 `model_config.json` 文件用于注册该目录下的模型。
|
||||
|
||||
## model_config.json 格式
|
||||
|
||||
### 基本格式
|
||||
|
||||
每个 `model_config.json` 文件是一个 JSON 数组,包含一个或多个模型配置对象:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"name": "模型名称",
|
||||
"module": "模型模块路径",
|
||||
"entry": "模型入口点"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### 字段说明
|
||||
|
||||
- **name**: 模型的唯一标识符,用于在配置文件中选择模型
|
||||
- **module**: 模型所在的模块路径,使用 Python 导入格式
|
||||
- **entry**: 模型的入口点,可以是类名或函数名
|
||||
|
||||
### 示例
|
||||
|
||||
#### 1. 单个模型
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"name": "DDGCRN",
|
||||
"module": "model.DDGCRN.DDGCRN",
|
||||
"entry": "DDGCRN"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
#### 2. 多个模型(同一目录下的不同版本)
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"name": "ASTRA",
|
||||
"module": "model.ASTRA.astra",
|
||||
"entry": "ASTRA"
|
||||
},
|
||||
{
|
||||
"name": "ASTRA_v2",
|
||||
"module": "model.ASTRA.astrav2",
|
||||
"entry": "ASTRA"
|
||||
},
|
||||
{
|
||||
"name": "ASTRA_v3",
|
||||
"module": "model.ASTRA.astrav3",
|
||||
"entry": "ASTRA"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
#### 3. 函数模型
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"name": "STGNCDE",
|
||||
"module": "model.STGNCDE.Make_model",
|
||||
"entry": "make_model"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
## 添加新模型
|
||||
|
||||
1. 在 `model` 目录下创建模型目录
|
||||
2. 在该目录下实现模型代码
|
||||
3. 创建 `model_config.json` 文件,配置模型信息
|
||||
4. 在配置文件中使用模型名称选择模型
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. 模型名称必须唯一,不允许重复
|
||||
2. 模块路径必须是正确的 Python 导入路径
|
||||
3. 入口点必须是模块中存在的类或函数
|
||||
4. 配置文件必须是有效的 JSON 格式
|
||||
5. 每个模型目录下只能有一个 `model_config.json` 文件
|
||||
|
||||
## 模型选择
|
||||
|
||||
在配置文件中,通过 `basic.model` 字段指定要使用的模型名称:
|
||||
|
||||
```json
|
||||
{
|
||||
"basic": {
|
||||
"model": "ASTRA"
|
||||
},
|
||||
"model": {
|
||||
// 模型特定配置
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 冲突检测
|
||||
|
||||
系统会自动检测模型名冲突,如有冲突会抛出 `AssertionError` 并显示冲突信息。
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "REPST",
|
||||
"module": "model.REPST.repst",
|
||||
"entry": "repst"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "STAEFormer",
|
||||
"module": "model.STAEFormer.STAEFormer",
|
||||
"entry": "STAEformer"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "STAWnet",
|
||||
"module": "model.STAWnet.STAWnet",
|
||||
"entry": "STAWnet"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "STFGNN",
|
||||
"module": "model.STFGNN.STFGNN",
|
||||
"entry": "STFGNN"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "STGCN",
|
||||
"module": "model.STGCN.models",
|
||||
"entry": "STGCNChebGraphConv"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "STGNCDE",
|
||||
"module": "model.STGNCDE.Make_model",
|
||||
"entry": "make_model"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "STGNRDE",
|
||||
"module": "model.STGNRDE.Make_model",
|
||||
"entry": "make_model"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "STGODE",
|
||||
"module": "model.STGODE.STGODE",
|
||||
"entry": "ODEGCN"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "STID",
|
||||
"module": "model.STID.STID",
|
||||
"entry": "STID"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "STIDGCN",
|
||||
"module": "model.STIDGCN.STIDGCN",
|
||||
"entry": "STIDGCN"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "STMLP",
|
||||
"module": "model.STMLP.STMLP",
|
||||
"entry": "STMLP"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "STSGCN",
|
||||
"module": "model.STSGCN.STSGCN",
|
||||
"entry": "STSGCN"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "ST_SSL",
|
||||
"module": "model.ST_SSL.ST_SSL",
|
||||
"entry": "STSSLModel"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "TCN",
|
||||
"module": "model.TCN.TCN",
|
||||
"entry": "TemporalConvNet"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "TWDGCN",
|
||||
"module": "model.TWDGCN.TWDGCN",
|
||||
"entry": "TWDGCN"
|
||||
}
|
||||
]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{
|
||||
"name": "iTransformer",
|
||||
"module": "model.iTransformer.iTransformer",
|
||||
"entry": "iTransformer"
|
||||
}
|
||||
]
|
||||
|
|
@ -1,107 +1,57 @@
|
|||
from model.DDGCRN.DDGCRN import DDGCRN
|
||||
from model.HI import HI
|
||||
from model.TWDGCN.TWDGCN import TWDGCN
|
||||
from model.AGCRN.AGCRN import AGCRN
|
||||
from model.NLT.HierAttnLstm import HierAttnLstm
|
||||
from model.STGNCDE.Make_model import make_model
|
||||
from model.DSANET.DSANET import DSANet
|
||||
from model.STGCN.models import STGCNChebGraphConv
|
||||
from model.DCRNN.dcrnn_model import DCRNNModel
|
||||
from model.ARIMA.ARIMA import ARIMA
|
||||
from model.TCN.TCN import TemporalConvNet
|
||||
from model.GWN.GraphWaveNet import gwnet
|
||||
from model.STFGNN.STFGNN import STFGNN
|
||||
from model.STSGCN.STSGCN import STSGCN
|
||||
from model.STGODE.STGODE import ODEGCN
|
||||
from model.PDG2SEQ.PDG2Seqb import PDG2Seq
|
||||
from model.STMLP.STMLP import STMLP
|
||||
from model.STIDGCN.STIDGCN import STIDGCN
|
||||
from model.STID.STID import STID
|
||||
from model.STAEFormer.STAEFormer import STAEformer
|
||||
from model.EXP.EXP32 import EXP as EXP
|
||||
from model.MegaCRN.MegaCRNModel import MegaCRNModel
|
||||
from model.ST_SSL.ST_SSL import STSSLModel
|
||||
from model.STGNRDE.Make_model import make_model as make_nrde_model
|
||||
from model.STAWnet.STAWnet import STAWnet
|
||||
from model.REPST.repst import repst as REPST
|
||||
from model.ASTRA.astra import ASTRA as ASTRA
|
||||
from model.ASTRA.astrav2 import ASTRA as ASTRAv2
|
||||
from model.ASTRA.astrav3 import ASTRA as ASTRAv3
|
||||
from model.iTransformer.iTransformer import iTransformer
|
||||
from model.Informer.model import Informer
|
||||
from model.HI.HI import HI
|
||||
from model.PatchTST.PatchTST import Model as PatchTST
|
||||
from model.MTGNN.MTGNN import gtnet as MTGNN
|
||||
import os
|
||||
import json
|
||||
import importlib
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
class ModelRegistry:
|
||||
def __init__(self):
|
||||
self.models = {}
|
||||
self.model_configs = {}
|
||||
self.model_dir = Path(__file__).parent
|
||||
self._load_model_configs()
|
||||
|
||||
def _load_model_configs(self):
|
||||
"""加载所有model_config.json文件"""
|
||||
# 直接遍历所有model_config.json文件
|
||||
for config_path in self.model_dir.rglob("model_config.json"):
|
||||
# 读取配置文件
|
||||
with open(config_path, 'r') as f:
|
||||
configs = json.load(f)
|
||||
|
||||
# 处理每个模型配置
|
||||
for config in configs:
|
||||
model_name = config["name"]
|
||||
# 检查模型名冲突
|
||||
assert model_name not in self.model_configs, f"模型名冲突: {model_name} 已存在,冲突文件: {config_path}"
|
||||
self.model_configs[model_name] = config
|
||||
|
||||
def _load_model(self, model_name):
|
||||
"""动态加载模型"""
|
||||
if model_name not in self.model_configs:
|
||||
raise ValueError(f"模型 {model_name} 未注册")
|
||||
|
||||
config = self.model_configs[model_name]
|
||||
module = importlib.import_module(config["module"])
|
||||
model_cls = getattr(module, config["entry"])
|
||||
self.models[model_name] = model_cls
|
||||
|
||||
def get_model(self, model_name):
|
||||
"""获取模型类或函数"""
|
||||
if model_name not in self.models:
|
||||
self._load_model(model_name)
|
||||
return self.models[model_name]
|
||||
|
||||
# 初始化模型注册表
|
||||
model_registry = ModelRegistry()
|
||||
|
||||
def model_selector(config):
|
||||
model_name = config["basic"]["model"]
|
||||
model_config = config["model"]
|
||||
match model_name:
|
||||
case "DDGCRN":
|
||||
return DDGCRN(model_config)
|
||||
case "TWDGCN":
|
||||
return TWDGCN(model_config)
|
||||
case "AGCRN":
|
||||
return AGCRN(model_config)
|
||||
case "NLT":
|
||||
return HierAttnLstm(model_config)
|
||||
case "STGNCDE":
|
||||
return make_model(model_config)
|
||||
case "DSANET":
|
||||
return DSANet(model_config)
|
||||
case "STGCN":
|
||||
return STGCNChebGraphConv(model_config)
|
||||
case "DCRNN":
|
||||
return DCRNNModel(model_config)
|
||||
case "ARIMA":
|
||||
return ARIMA(model_config)
|
||||
case "TCN":
|
||||
return TemporalConvNet(model_config)
|
||||
case "GWN":
|
||||
return gwnet(model_config)
|
||||
case "STFGNN":
|
||||
return STFGNN(model_config)
|
||||
case "STSGCN":
|
||||
return STSGCN(model_config)
|
||||
case "STGODE":
|
||||
return ODEGCN(model_config)
|
||||
case "PDG2SEQ":
|
||||
return PDG2Seq(model_config)
|
||||
case "STMLP":
|
||||
return STMLP(model_config)
|
||||
case "STIDGCN":
|
||||
return STIDGCN(model_config)
|
||||
case "STID":
|
||||
return STID(model_config)
|
||||
case "STAEFormer":
|
||||
return STAEformer(model_config)
|
||||
case "EXP":
|
||||
return EXP(model_config)
|
||||
case "MegaCRN":
|
||||
return MegaCRNModel(model_config)
|
||||
case "ST_SSL":
|
||||
return STSSLModel(model_config)
|
||||
case "STGNRDE":
|
||||
return make_nrde_model(model_config)
|
||||
case "STAWnet":
|
||||
return STAWnet(model_config)
|
||||
case "REPST":
|
||||
return REPST(model_config)
|
||||
case "ASTRA":
|
||||
return ASTRA(model_config)
|
||||
case "ASTRA_v2":
|
||||
return ASTRAv2(model_config)
|
||||
case "ASTRA_v3":
|
||||
return ASTRAv3(model_config)
|
||||
case "iTransformer":
|
||||
return iTransformer(model_config)
|
||||
case "Informer":
|
||||
return Informer(model_config)
|
||||
case "HI":
|
||||
return HI(model_config)
|
||||
case "PatchTST":
|
||||
return PatchTST(model_config)
|
||||
case "MTGNN":
|
||||
return MTGNN(model_config)
|
||||
|
||||
model_cls = model_registry.get_model(model_name)
|
||||
model = model_cls(model_config)
|
||||
# print(f"\n=== 模型选择结果 ===")
|
||||
print(f"选择的模型: {model_name}")
|
||||
print(f"模型入口: {model_registry.model_configs[model_name]['module']}:{model_registry.model_configs[model_name]['entry']}")
|
||||
return model
|
||||
|
|
|
|||
6
train.py
6
train.py
|
|
@ -13,7 +13,7 @@ def read_config(config_path):
|
|||
# 全局配置
|
||||
device = "cuda:0" # 指定设备为cuda:0
|
||||
seed = 2023 # 随机种子
|
||||
epochs = 1
|
||||
epochs = 100
|
||||
|
||||
# 拷贝项
|
||||
config["basic"]["device"] = device
|
||||
|
|
@ -63,14 +63,14 @@ def run(config):
|
|||
|
||||
if __name__ == "__main__":
|
||||
# 指定模型
|
||||
model_list = ["Informer"]
|
||||
model_list = ["iTransformer"]
|
||||
# 指定数据集
|
||||
dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"]
|
||||
# dataset_list = ["PEMS-BAY"]
|
||||
|
||||
# 我的调试开关,不做测试就填 str(False)
|
||||
# os.environ["TRY"] = str(False)
|
||||
os.environ["TRY"] = str(True)
|
||||
os.environ["TRY"] = str(False)
|
||||
|
||||
for model in model_list:
|
||||
for dataset in dataset_list:
|
||||
|
|
|
|||
Loading…
Reference in New Issue