From 659b41f6123a43b2513aabd325b35342df610f5f Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Mon, 15 Dec 2025 21:33:28 +0800 Subject: [PATCH] =?UTF-8?q?refactor(config):=20=E7=BB=9F=E4=B8=80=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E6=96=87=E4=BB=B6=E5=91=BD=E5=90=8D=E5=B9=B6=E8=B0=83?= =?UTF-8?q?=E6=95=B4=E6=A8=A1=E5=9E=8B=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 更新多个模型的配置文件命名格式,统一使用大驼峰格式 调整SolarEnergy和BJTaxi数据集的输入维度和批量大小 删除旧命名格式的配置文件并添加新的配置文件 修改训练脚本中的模型和数据集列表用于调试 --- ...JTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} | 0 config/HI/SolarEnergy.yaml | 6 +- ...{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} | 0 ...JTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} | 0 ...{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} | 0 ...JTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} | 0 ...YCBike-Inflow.yaml => NYCBike-InFlow.yaml} | 0 ...Bike-Outflow.yaml => NYCBike-OutFlow.yaml} | 0 config/PatchTST/SolarEnergy.yaml | 2 +- config/REPST/BJTaxi-InFlow.yaml | 5 +- config/REPST/BJTaxi-Inflow.yaml | 55 ------------------- ...{BJTaxi_Inflow.yaml => BJTaxi_InFlow.yaml} | 0 ...JTaxi_Outflow.yaml => BJTaxi_OutFlow.yaml} | 0 ...YCBike_Inflow.yaml => NYCBike_InFlow.yaml} | 0 ...Bike_Outflow.yaml => NYCBike_OutFlow.yaml} | 0 train.py | 8 +-- 16 files changed, 10 insertions(+), 66 deletions(-) rename config/HI/{BJTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} (100%) rename config/MTGNN/{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} (100%) rename config/MTGNN/{BJTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} (100%) rename config/PatchTST/{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} (100%) rename config/PatchTST/{BJTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} (100%) rename config/PatchTST/{NYCBike-Inflow.yaml => NYCBike-InFlow.yaml} (100%) rename config/PatchTST/{NYCBike-Outflow.yaml => NYCBike-OutFlow.yaml} (100%) mode change 100644 => 100755 config/REPST/BJTaxi-InFlow.yaml delete mode 100755 config/REPST/BJTaxi-Inflow.yaml rename config/STID/{BJTaxi_Inflow.yaml => BJTaxi_InFlow.yaml} (100%) rename config/STID/{BJTaxi_Outflow.yaml => BJTaxi_OutFlow.yaml} (100%) rename config/STID/{NYCBike_Inflow.yaml => NYCBike_InFlow.yaml} (100%) rename config/STID/{NYCBike_Outflow.yaml => NYCBike_OutFlow.yaml} (100%) diff --git a/config/HI/BJTaxi-Outflow.yaml b/config/HI/BJTaxi-OutFlow.yaml similarity index 100% rename from config/HI/BJTaxi-Outflow.yaml rename to config/HI/BJTaxi-OutFlow.yaml diff --git a/config/HI/SolarEnergy.yaml b/config/HI/SolarEnergy.yaml index 1558d3a..aa07cf6 100644 --- a/config/HI/SolarEnergy.yaml +++ b/config/HI/SolarEnergy.yaml @@ -6,11 +6,11 @@ basic: seed: 2023 data: - batch_size: 512 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 - input_dim: 6 + input_dim: 1 lag: 24 normalizer: std num_nodes: 137 @@ -25,7 +25,7 @@ model: train: - batch_size: 512 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/MTGNN/BJTaxi-Inflow.yaml b/config/MTGNN/BJTaxi-InFlow.yaml similarity index 100% rename from config/MTGNN/BJTaxi-Inflow.yaml rename to config/MTGNN/BJTaxi-InFlow.yaml diff --git a/config/MTGNN/BJTaxi-Outflow.yaml b/config/MTGNN/BJTaxi-OutFlow.yaml similarity index 100% rename from config/MTGNN/BJTaxi-Outflow.yaml rename to config/MTGNN/BJTaxi-OutFlow.yaml diff --git a/config/PatchTST/BJTaxi-Inflow.yaml b/config/PatchTST/BJTaxi-InFlow.yaml similarity index 100% rename from config/PatchTST/BJTaxi-Inflow.yaml rename to config/PatchTST/BJTaxi-InFlow.yaml diff --git a/config/PatchTST/BJTaxi-Outflow.yaml b/config/PatchTST/BJTaxi-OutFlow.yaml similarity index 100% rename from config/PatchTST/BJTaxi-Outflow.yaml rename to config/PatchTST/BJTaxi-OutFlow.yaml diff --git a/config/PatchTST/NYCBike-Inflow.yaml b/config/PatchTST/NYCBike-InFlow.yaml similarity index 100% rename from config/PatchTST/NYCBike-Inflow.yaml rename to config/PatchTST/NYCBike-InFlow.yaml diff --git a/config/PatchTST/NYCBike-Outflow.yaml b/config/PatchTST/NYCBike-OutFlow.yaml similarity index 100% rename from config/PatchTST/NYCBike-Outflow.yaml rename to config/PatchTST/NYCBike-OutFlow.yaml diff --git a/config/PatchTST/SolarEnergy.yaml b/config/PatchTST/SolarEnergy.yaml index d31a458..b6ca055 100644 --- a/config/PatchTST/SolarEnergy.yaml +++ b/config/PatchTST/SolarEnergy.yaml @@ -10,7 +10,7 @@ data: column_wise: false days_per_week: 7 horizon: 24 - input_dim: 6 + input_dim: 1 lag: 24 normalizer: std num_nodes: 137 diff --git a/config/REPST/BJTaxi-InFlow.yaml b/config/REPST/BJTaxi-InFlow.yaml old mode 100644 new mode 100755 index 3191eba..e8a17fc --- a/config/REPST/BJTaxi-InFlow.yaml +++ b/config/REPST/BJTaxi-InFlow.yaml @@ -1,6 +1,6 @@ basic: dataset: BJTaxi-InFlow - device: cuda:1 + device: cuda:0 mode: train model: REPST seed: 2023 @@ -27,7 +27,6 @@ model: input_dim: 1 n_heads: 1 num_nodes: 1024 - output_dim: 1 patch_len: 6 pred_len: 24 seq_len: 24 @@ -41,7 +40,7 @@ train: early_stop_patience: 15 epochs: 100 grad_norm: false - log_step: 1000 + log_step: 100 loss_func: mae lr_decay: true lr_decay_rate: 0.3 diff --git a/config/REPST/BJTaxi-Inflow.yaml b/config/REPST/BJTaxi-Inflow.yaml deleted file mode 100755 index e8a17fc..0000000 --- a/config/REPST/BJTaxi-Inflow.yaml +++ /dev/null @@ -1,55 +0,0 @@ -basic: - dataset: BJTaxi-InFlow - device: cuda:0 - mode: train - model: REPST - seed: 2023 - -data: - batch_size: 16 - column_wise: false - days_per_week: 7 - horizon: 24 - input_dim: 1 - lag: 24 - normalizer: std - num_nodes: 1024 - steps_per_day: 48 - test_ratio: 0.2 - val_ratio: 0.2 - -model: - d_ff: 128 - d_model: 64 - dropout: 0.2 - gpt_layers: 9 - gpt_path: ./GPT-2 - input_dim: 1 - n_heads: 1 - num_nodes: 1024 - patch_len: 6 - pred_len: 24 - seq_len: 24 - stride: 7 - word_num: 1000 - -train: - batch_size: 16 - debug: false - early_stop: true - early_stop_patience: 15 - epochs: 100 - grad_norm: false - log_step: 100 - loss_func: mae - lr_decay: true - lr_decay_rate: 0.3 - lr_decay_step: 5,20,40,70 - lr_init: 0.003 - mae_thresh: None - mape_thresh: 0.001 - max_grad_norm: 5 - output_dim: 1 - plot: false - real_value: true - weight_decay: 0 diff --git a/config/STID/BJTaxi_Inflow.yaml b/config/STID/BJTaxi_InFlow.yaml similarity index 100% rename from config/STID/BJTaxi_Inflow.yaml rename to config/STID/BJTaxi_InFlow.yaml diff --git a/config/STID/BJTaxi_Outflow.yaml b/config/STID/BJTaxi_OutFlow.yaml similarity index 100% rename from config/STID/BJTaxi_Outflow.yaml rename to config/STID/BJTaxi_OutFlow.yaml diff --git a/config/STID/NYCBike_Inflow.yaml b/config/STID/NYCBike_InFlow.yaml similarity index 100% rename from config/STID/NYCBike_Inflow.yaml rename to config/STID/NYCBike_InFlow.yaml diff --git a/config/STID/NYCBike_Outflow.yaml b/config/STID/NYCBike_OutFlow.yaml similarity index 100% rename from config/STID/NYCBike_Outflow.yaml rename to config/STID/NYCBike_OutFlow.yaml diff --git a/train.py b/train.py index b0b5af1..fcaaa6a 100644 --- a/train.py +++ b/train.py @@ -89,10 +89,10 @@ def main(model, data, debug=False): if __name__ == "__main__": # 调试用 - # model_list = ["iTransformer", "PatchTST", "HI"] + model_list = ["iTransformer", "PatchTST", "HI"] # model_list = ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"] - model_list = ["iTransformer"] - dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] + # model_list = ["iTransformer"] + # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] # dataset_list = ["AirQuality"] - # dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"] + dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"] main(model_list, dataset_list, debug = True) \ No newline at end of file