From 5827554c735f1a0491ca32232ec73ab6f0c01b9b Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Sun, 14 Dec 2025 17:48:37 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=B9=E8=BF=9B=E6=A8=A1=E5=9E=8B=E6=B3=A8?= =?UTF-8?q?=E5=86=8C=EF=BC=8C=E5=8A=A8=E6=80=81=E6=B3=A8=E5=86=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/AGCRN/model_config.json | 7 ++ model/ARIMA/model_config.json | 7 ++ model/ASTRA/model_config.json | 17 +++ model/DCRNN/model_config.json | 7 ++ model/DDGCRN/model_config.json | 7 ++ model/DSANET/model_config.json | 7 ++ model/EXP/model_config.json | 7 ++ model/GWN/model_config.json | 7 ++ model/HI/model_config.json | 7 ++ model/Informer/model_config.json | 7 ++ model/MTGNN/model_config.json | 7 ++ model/MegaCRN/model_config.json | 7 ++ model/NLT/model_config.json | 7 ++ model/PDG2SEQ/model_config.json | 7 ++ model/PatchTST/model_config.json | 7 ++ model/README.md | 109 +++++++++++++++++++ model/REPST/model_config.json | 7 ++ model/STAEFormer/model_config.json | 7 ++ model/STAWnet/model_config.json | 7 ++ model/STFGNN/model_config.json | 7 ++ model/STGCN/model_config.json | 7 ++ model/STGNCDE/model_config.json | 7 ++ model/STGNRDE/model_config.json | 7 ++ model/STGODE/model_config.json | 7 ++ model/STID/model_config.json | 7 ++ model/STIDGCN/model_config.json | 7 ++ model/STMLP/model_config.json | 7 ++ model/STSGCN/model_config.json | 7 ++ model/ST_SSL/model_config.json | 7 ++ model/TCN/model_config.json | 7 ++ model/TWDGCN/model_config.json | 7 ++ model/iTransformer/model_config.json | 7 ++ model/model_selector.py | 152 +++++++++------------------ train.py | 6 +- 34 files changed, 390 insertions(+), 104 deletions(-) create mode 100644 model/AGCRN/model_config.json create mode 100644 model/ARIMA/model_config.json create mode 100644 model/ASTRA/model_config.json create mode 100644 model/DCRNN/model_config.json create mode 100644 model/DDGCRN/model_config.json create mode 100644 model/DSANET/model_config.json create mode 100644 model/EXP/model_config.json create mode 100644 model/GWN/model_config.json create mode 100644 model/HI/model_config.json create mode 100644 model/Informer/model_config.json create mode 100644 model/MTGNN/model_config.json create mode 100644 model/MegaCRN/model_config.json create mode 100644 model/NLT/model_config.json create mode 100644 model/PDG2SEQ/model_config.json create mode 100644 model/PatchTST/model_config.json create mode 100644 model/README.md create mode 100644 model/REPST/model_config.json create mode 100644 model/STAEFormer/model_config.json create mode 100644 model/STAWnet/model_config.json create mode 100644 model/STFGNN/model_config.json create mode 100644 model/STGCN/model_config.json create mode 100644 model/STGNCDE/model_config.json create mode 100644 model/STGNRDE/model_config.json create mode 100644 model/STGODE/model_config.json create mode 100644 model/STID/model_config.json create mode 100644 model/STIDGCN/model_config.json create mode 100644 model/STMLP/model_config.json create mode 100644 model/STSGCN/model_config.json create mode 100644 model/ST_SSL/model_config.json create mode 100644 model/TCN/model_config.json create mode 100644 model/TWDGCN/model_config.json create mode 100644 model/iTransformer/model_config.json diff --git a/model/AGCRN/model_config.json b/model/AGCRN/model_config.json new file mode 100644 index 0000000..e1c9b61 --- /dev/null +++ b/model/AGCRN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "AGCRN", + "module": "model.AGCRN.AGCRN", + "entry": "AGCRN" + } +] \ No newline at end of file diff --git a/model/ARIMA/model_config.json b/model/ARIMA/model_config.json new file mode 100644 index 0000000..9b33c5c --- /dev/null +++ b/model/ARIMA/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "ARIMA", + "module": "model.ARIMA.ARIMA", + "entry": "ARIMA" + } +] \ No newline at end of file diff --git a/model/ASTRA/model_config.json b/model/ASTRA/model_config.json new file mode 100644 index 0000000..3cd0064 --- /dev/null +++ b/model/ASTRA/model_config.json @@ -0,0 +1,17 @@ +[ + { + "name": "ASTRA", + "module": "model.ASTRA.astra", + "entry": "ASTRA" + }, + { + "name": "ASTRA_v2", + "module": "model.ASTRA.astrav2", + "entry": "ASTRA" + }, + { + "name": "ASTRA_v3", + "module": "model.ASTRA.astrav3", + "entry": "ASTRA" + } +] \ No newline at end of file diff --git a/model/DCRNN/model_config.json b/model/DCRNN/model_config.json new file mode 100644 index 0000000..c92b599 --- /dev/null +++ b/model/DCRNN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "DCRNN", + "module": "model.DCRNN.dcrnn_model", + "entry": "DCRNNModel" + } +] \ No newline at end of file diff --git a/model/DDGCRN/model_config.json b/model/DDGCRN/model_config.json new file mode 100644 index 0000000..a07fc3a --- /dev/null +++ b/model/DDGCRN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "DDGCRN", + "module": "model.DDGCRN.DDGCRN", + "entry": "DDGCRN" + } +] \ No newline at end of file diff --git a/model/DSANET/model_config.json b/model/DSANET/model_config.json new file mode 100644 index 0000000..5624f8a --- /dev/null +++ b/model/DSANET/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "DSANET", + "module": "model.DSANET.DSANET", + "entry": "DSANet" + } +] \ No newline at end of file diff --git a/model/EXP/model_config.json b/model/EXP/model_config.json new file mode 100644 index 0000000..bdf39b7 --- /dev/null +++ b/model/EXP/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "EXP", + "module": "model.EXP.EXP32", + "entry": "EXP" + } +] \ No newline at end of file diff --git a/model/GWN/model_config.json b/model/GWN/model_config.json new file mode 100644 index 0000000..38d05b4 --- /dev/null +++ b/model/GWN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "GWN", + "module": "model.GWN.GraphWaveNet", + "entry": "gwnet" + } +] \ No newline at end of file diff --git a/model/HI/model_config.json b/model/HI/model_config.json new file mode 100644 index 0000000..3071864 --- /dev/null +++ b/model/HI/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "HI", + "module": "model.HI.HI", + "entry": "HI" + } +] \ No newline at end of file diff --git a/model/Informer/model_config.json b/model/Informer/model_config.json new file mode 100644 index 0000000..3836cd0 --- /dev/null +++ b/model/Informer/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "Informer", + "module": "model.Informer.model", + "entry": "Informer" + } +] \ No newline at end of file diff --git a/model/MTGNN/model_config.json b/model/MTGNN/model_config.json new file mode 100644 index 0000000..94aa32c --- /dev/null +++ b/model/MTGNN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "MTGNN", + "module": "model.MTGNN.MTGNN", + "entry": "gtnet" + } +] \ No newline at end of file diff --git a/model/MegaCRN/model_config.json b/model/MegaCRN/model_config.json new file mode 100644 index 0000000..e8c0599 --- /dev/null +++ b/model/MegaCRN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "MegaCRN", + "module": "model.MegaCRN.MegaCRNModel", + "entry": "MegaCRNModel" + } +] \ No newline at end of file diff --git a/model/NLT/model_config.json b/model/NLT/model_config.json new file mode 100644 index 0000000..a99a6b1 --- /dev/null +++ b/model/NLT/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "NLT", + "module": "model.NLT.HierAttnLstm", + "entry": "HierAttnLstm" + } +] \ No newline at end of file diff --git a/model/PDG2SEQ/model_config.json b/model/PDG2SEQ/model_config.json new file mode 100644 index 0000000..783f3bf --- /dev/null +++ b/model/PDG2SEQ/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "PDG2SEQ", + "module": "model.PDG2SEQ.PDG2Seqb", + "entry": "PDG2Seq" + } +] \ No newline at end of file diff --git a/model/PatchTST/model_config.json b/model/PatchTST/model_config.json new file mode 100644 index 0000000..d613fbb --- /dev/null +++ b/model/PatchTST/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "PatchTST", + "module": "model.PatchTST.PatchTST", + "entry": "Model" + } +] \ No newline at end of file diff --git a/model/README.md b/model/README.md new file mode 100644 index 0000000..24dd3ca --- /dev/null +++ b/model/README.md @@ -0,0 +1,109 @@ +# 模型注册说明 + +## 概述 + +本项目使用基于配置文件的模型注册机制,每个模型目录下的 `model_config.json` 文件用于注册该目录下的模型。 + +## model_config.json 格式 + +### 基本格式 + +每个 `model_config.json` 文件是一个 JSON 数组,包含一个或多个模型配置对象: + +```json +[ + { + "name": "模型名称", + "module": "模型模块路径", + "entry": "模型入口点" + } +] +``` + +### 字段说明 + +- **name**: 模型的唯一标识符,用于在配置文件中选择模型 +- **module**: 模型所在的模块路径,使用 Python 导入格式 +- **entry**: 模型的入口点,可以是类名或函数名 + +### 示例 + +#### 1. 单个模型 + +```json +[ + { + "name": "DDGCRN", + "module": "model.DDGCRN.DDGCRN", + "entry": "DDGCRN" + } +] +``` + +#### 2. 多个模型(同一目录下的不同版本) + +```json +[ + { + "name": "ASTRA", + "module": "model.ASTRA.astra", + "entry": "ASTRA" + }, + { + "name": "ASTRA_v2", + "module": "model.ASTRA.astrav2", + "entry": "ASTRA" + }, + { + "name": "ASTRA_v3", + "module": "model.ASTRA.astrav3", + "entry": "ASTRA" + } +] +``` + +#### 3. 函数模型 + +```json +[ + { + "name": "STGNCDE", + "module": "model.STGNCDE.Make_model", + "entry": "make_model" + } +] +``` + +## 添加新模型 + +1. 在 `model` 目录下创建模型目录 +2. 在该目录下实现模型代码 +3. 创建 `model_config.json` 文件,配置模型信息 +4. 在配置文件中使用模型名称选择模型 + +## 注意事项 + +1. 模型名称必须唯一,不允许重复 +2. 模块路径必须是正确的 Python 导入路径 +3. 入口点必须是模块中存在的类或函数 +4. 配置文件必须是有效的 JSON 格式 +5. 每个模型目录下只能有一个 `model_config.json` 文件 + +## 模型选择 + +在配置文件中,通过 `basic.model` 字段指定要使用的模型名称: + +```json +{ + "basic": { + "model": "ASTRA" + }, + "model": { + // 模型特定配置 + } +} +``` + +## 冲突检测 + +系统会自动检测模型名冲突,如有冲突会抛出 `AssertionError` 并显示冲突信息。 diff --git a/model/REPST/model_config.json b/model/REPST/model_config.json new file mode 100644 index 0000000..5bdfce6 --- /dev/null +++ b/model/REPST/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "REPST", + "module": "model.REPST.repst", + "entry": "repst" + } +] \ No newline at end of file diff --git a/model/STAEFormer/model_config.json b/model/STAEFormer/model_config.json new file mode 100644 index 0000000..8823a88 --- /dev/null +++ b/model/STAEFormer/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STAEFormer", + "module": "model.STAEFormer.STAEFormer", + "entry": "STAEformer" + } +] \ No newline at end of file diff --git a/model/STAWnet/model_config.json b/model/STAWnet/model_config.json new file mode 100644 index 0000000..0e83de9 --- /dev/null +++ b/model/STAWnet/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STAWnet", + "module": "model.STAWnet.STAWnet", + "entry": "STAWnet" + } +] \ No newline at end of file diff --git a/model/STFGNN/model_config.json b/model/STFGNN/model_config.json new file mode 100644 index 0000000..ef5bd7e --- /dev/null +++ b/model/STFGNN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STFGNN", + "module": "model.STFGNN.STFGNN", + "entry": "STFGNN" + } +] \ No newline at end of file diff --git a/model/STGCN/model_config.json b/model/STGCN/model_config.json new file mode 100644 index 0000000..af5885a --- /dev/null +++ b/model/STGCN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STGCN", + "module": "model.STGCN.models", + "entry": "STGCNChebGraphConv" + } +] \ No newline at end of file diff --git a/model/STGNCDE/model_config.json b/model/STGNCDE/model_config.json new file mode 100644 index 0000000..3ec8745 --- /dev/null +++ b/model/STGNCDE/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STGNCDE", + "module": "model.STGNCDE.Make_model", + "entry": "make_model" + } +] \ No newline at end of file diff --git a/model/STGNRDE/model_config.json b/model/STGNRDE/model_config.json new file mode 100644 index 0000000..ec655a8 --- /dev/null +++ b/model/STGNRDE/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STGNRDE", + "module": "model.STGNRDE.Make_model", + "entry": "make_model" + } +] \ No newline at end of file diff --git a/model/STGODE/model_config.json b/model/STGODE/model_config.json new file mode 100644 index 0000000..d6a03e2 --- /dev/null +++ b/model/STGODE/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STGODE", + "module": "model.STGODE.STGODE", + "entry": "ODEGCN" + } +] \ No newline at end of file diff --git a/model/STID/model_config.json b/model/STID/model_config.json new file mode 100644 index 0000000..1a39d87 --- /dev/null +++ b/model/STID/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STID", + "module": "model.STID.STID", + "entry": "STID" + } +] \ No newline at end of file diff --git a/model/STIDGCN/model_config.json b/model/STIDGCN/model_config.json new file mode 100644 index 0000000..a986383 --- /dev/null +++ b/model/STIDGCN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STIDGCN", + "module": "model.STIDGCN.STIDGCN", + "entry": "STIDGCN" + } +] \ No newline at end of file diff --git a/model/STMLP/model_config.json b/model/STMLP/model_config.json new file mode 100644 index 0000000..e7cfb08 --- /dev/null +++ b/model/STMLP/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STMLP", + "module": "model.STMLP.STMLP", + "entry": "STMLP" + } +] \ No newline at end of file diff --git a/model/STSGCN/model_config.json b/model/STSGCN/model_config.json new file mode 100644 index 0000000..a5e2b4d --- /dev/null +++ b/model/STSGCN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STSGCN", + "module": "model.STSGCN.STSGCN", + "entry": "STSGCN" + } +] \ No newline at end of file diff --git a/model/ST_SSL/model_config.json b/model/ST_SSL/model_config.json new file mode 100644 index 0000000..8bbfb74 --- /dev/null +++ b/model/ST_SSL/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "ST_SSL", + "module": "model.ST_SSL.ST_SSL", + "entry": "STSSLModel" + } +] \ No newline at end of file diff --git a/model/TCN/model_config.json b/model/TCN/model_config.json new file mode 100644 index 0000000..d083150 --- /dev/null +++ b/model/TCN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "TCN", + "module": "model.TCN.TCN", + "entry": "TemporalConvNet" + } +] \ No newline at end of file diff --git a/model/TWDGCN/model_config.json b/model/TWDGCN/model_config.json new file mode 100644 index 0000000..92f3167 --- /dev/null +++ b/model/TWDGCN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "TWDGCN", + "module": "model.TWDGCN.TWDGCN", + "entry": "TWDGCN" + } +] \ No newline at end of file diff --git a/model/iTransformer/model_config.json b/model/iTransformer/model_config.json new file mode 100644 index 0000000..79c8db5 --- /dev/null +++ b/model/iTransformer/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "iTransformer", + "module": "model.iTransformer.iTransformer", + "entry": "iTransformer" + } +] \ No newline at end of file diff --git a/model/model_selector.py b/model/model_selector.py index f74dde2..9afd0ff 100755 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -1,107 +1,57 @@ -from model.DDGCRN.DDGCRN import DDGCRN -from model.HI import HI -from model.TWDGCN.TWDGCN import TWDGCN -from model.AGCRN.AGCRN import AGCRN -from model.NLT.HierAttnLstm import HierAttnLstm -from model.STGNCDE.Make_model import make_model -from model.DSANET.DSANET import DSANet -from model.STGCN.models import STGCNChebGraphConv -from model.DCRNN.dcrnn_model import DCRNNModel -from model.ARIMA.ARIMA import ARIMA -from model.TCN.TCN import TemporalConvNet -from model.GWN.GraphWaveNet import gwnet -from model.STFGNN.STFGNN import STFGNN -from model.STSGCN.STSGCN import STSGCN -from model.STGODE.STGODE import ODEGCN -from model.PDG2SEQ.PDG2Seqb import PDG2Seq -from model.STMLP.STMLP import STMLP -from model.STIDGCN.STIDGCN import STIDGCN -from model.STID.STID import STID -from model.STAEFormer.STAEFormer import STAEformer -from model.EXP.EXP32 import EXP as EXP -from model.MegaCRN.MegaCRNModel import MegaCRNModel -from model.ST_SSL.ST_SSL import STSSLModel -from model.STGNRDE.Make_model import make_model as make_nrde_model -from model.STAWnet.STAWnet import STAWnet -from model.REPST.repst import repst as REPST -from model.ASTRA.astra import ASTRA as ASTRA -from model.ASTRA.astrav2 import ASTRA as ASTRAv2 -from model.ASTRA.astrav3 import ASTRA as ASTRAv3 -from model.iTransformer.iTransformer import iTransformer -from model.Informer.model import Informer -from model.HI.HI import HI -from model.PatchTST.PatchTST import Model as PatchTST -from model.MTGNN.MTGNN import gtnet as MTGNN +import os +import json +import importlib +import sys +from pathlib import Path +class ModelRegistry: + def __init__(self): + self.models = {} + self.model_configs = {} + self.model_dir = Path(__file__).parent + self._load_model_configs() + + def _load_model_configs(self): + """加载所有model_config.json文件""" + # 直接遍历所有model_config.json文件 + for config_path in self.model_dir.rglob("model_config.json"): + # 读取配置文件 + with open(config_path, 'r') as f: + configs = json.load(f) + + # 处理每个模型配置 + for config in configs: + model_name = config["name"] + # 检查模型名冲突 + assert model_name not in self.model_configs, f"模型名冲突: {model_name} 已存在,冲突文件: {config_path}" + self.model_configs[model_name] = config + + def _load_model(self, model_name): + """动态加载模型""" + if model_name not in self.model_configs: + raise ValueError(f"模型 {model_name} 未注册") + + config = self.model_configs[model_name] + module = importlib.import_module(config["module"]) + model_cls = getattr(module, config["entry"]) + self.models[model_name] = model_cls + + def get_model(self, model_name): + """获取模型类或函数""" + if model_name not in self.models: + self._load_model(model_name) + return self.models[model_name] +# 初始化模型注册表 +model_registry = ModelRegistry() def model_selector(config): model_name = config["basic"]["model"] model_config = config["model"] - match model_name: - case "DDGCRN": - return DDGCRN(model_config) - case "TWDGCN": - return TWDGCN(model_config) - case "AGCRN": - return AGCRN(model_config) - case "NLT": - return HierAttnLstm(model_config) - case "STGNCDE": - return make_model(model_config) - case "DSANET": - return DSANet(model_config) - case "STGCN": - return STGCNChebGraphConv(model_config) - case "DCRNN": - return DCRNNModel(model_config) - case "ARIMA": - return ARIMA(model_config) - case "TCN": - return TemporalConvNet(model_config) - case "GWN": - return gwnet(model_config) - case "STFGNN": - return STFGNN(model_config) - case "STSGCN": - return STSGCN(model_config) - case "STGODE": - return ODEGCN(model_config) - case "PDG2SEQ": - return PDG2Seq(model_config) - case "STMLP": - return STMLP(model_config) - case "STIDGCN": - return STIDGCN(model_config) - case "STID": - return STID(model_config) - case "STAEFormer": - return STAEformer(model_config) - case "EXP": - return EXP(model_config) - case "MegaCRN": - return MegaCRNModel(model_config) - case "ST_SSL": - return STSSLModel(model_config) - case "STGNRDE": - return make_nrde_model(model_config) - case "STAWnet": - return STAWnet(model_config) - case "REPST": - return REPST(model_config) - case "ASTRA": - return ASTRA(model_config) - case "ASTRA_v2": - return ASTRAv2(model_config) - case "ASTRA_v3": - return ASTRAv3(model_config) - case "iTransformer": - return iTransformer(model_config) - case "Informer": - return Informer(model_config) - case "HI": - return HI(model_config) - case "PatchTST": - return PatchTST(model_config) - case "MTGNN": - return MTGNN(model_config) + + model_cls = model_registry.get_model(model_name) + model = model_cls(model_config) + # print(f"\n=== 模型选择结果 ===") + print(f"选择的模型: {model_name}") + print(f"模型入口: {model_registry.model_configs[model_name]['module']}:{model_registry.model_configs[model_name]['entry']}") + return model diff --git a/train.py b/train.py index 9c81209..5beb472 100644 --- a/train.py +++ b/train.py @@ -13,7 +13,7 @@ def read_config(config_path): # 全局配置 device = "cuda:0" # 指定设备为cuda:0 seed = 2023 # 随机种子 - epochs = 1 + epochs = 100 # 拷贝项 config["basic"]["device"] = device @@ -63,14 +63,14 @@ def run(config): if __name__ == "__main__": # 指定模型 - model_list = ["Informer"] + model_list = ["iTransformer"] # 指定数据集 dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] # dataset_list = ["PEMS-BAY"] # 我的调试开关,不做测试就填 str(False) # os.environ["TRY"] = str(False) - os.environ["TRY"] = str(True) + os.environ["TRY"] = str(False) for model in model_list: for dataset in dataset_list: