From 2bfa444f8eab1e00f4ebad0dd0d3c03a6abd2d20 Mon Sep 17 00:00:00 2001 From: "haoyu.he" Date: Sun, 4 Jan 2026 14:00:15 +0800 Subject: [PATCH 1/2] stid config --- config/STID/AirQuality.yaml | 2 +- config/STID/BJTaxi-InFlow.yaml | 2 +- config/STID/BJTaxi-OutFlow.yaml | 2 +- config/STID/BJTaxi_InFlow.yaml | 2 +- config/STID/BJTaxi_OutFlow.yaml | 2 +- config/STID/METR-LA.yaml | 2 +- config/STID/NYCBike-InFlow.yaml | 2 +- config/STID/NYCBike-OutFlow.yaml | 2 +- config/STID/NYCBike_InFlow.yaml | 2 +- config/STID/NYCBike_OutFlow.yaml | 2 +- config/STID/PEMS-BAY.yaml | 2 +- config/STID/SolarEnergy.yaml | 2 +- train.py | 8 ++++---- 13 files changed, 16 insertions(+), 16 deletions(-) diff --git a/config/STID/AirQuality.yaml b/config/STID/AirQuality.yaml index b480b4f..46098c6 100755 --- a/config/STID/AirQuality.yaml +++ b/config/STID/AirQuality.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/BJTaxi-InFlow.yaml b/config/STID/BJTaxi-InFlow.yaml index 59e9501..b3e7b87 100644 --- a/config/STID/BJTaxi-InFlow.yaml +++ b/config/STID/BJTaxi-InFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/BJTaxi-OutFlow.yaml b/config/STID/BJTaxi-OutFlow.yaml index e2fdf43..822f74c 100644 --- a/config/STID/BJTaxi-OutFlow.yaml +++ b/config/STID/BJTaxi-OutFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/BJTaxi_InFlow.yaml b/config/STID/BJTaxi_InFlow.yaml index d50ba22..d12df1b 100755 --- a/config/STID/BJTaxi_InFlow.yaml +++ b/config/STID/BJTaxi_InFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/BJTaxi_OutFlow.yaml b/config/STID/BJTaxi_OutFlow.yaml index e2fdf43..822f74c 100755 --- a/config/STID/BJTaxi_OutFlow.yaml +++ b/config/STID/BJTaxi_OutFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/METR-LA.yaml b/config/STID/METR-LA.yaml index 7ceb4f0..7ab5199 100755 --- a/config/STID/METR-LA.yaml +++ b/config/STID/METR-LA.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 300 diff --git a/config/STID/NYCBike-InFlow.yaml b/config/STID/NYCBike-InFlow.yaml index e509007..324a491 100644 --- a/config/STID/NYCBike-InFlow.yaml +++ b/config/STID/NYCBike-InFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/NYCBike-OutFlow.yaml b/config/STID/NYCBike-OutFlow.yaml index 155baf3..c77ac79 100644 --- a/config/STID/NYCBike-OutFlow.yaml +++ b/config/STID/NYCBike-OutFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/NYCBike_InFlow.yaml b/config/STID/NYCBike_InFlow.yaml index e509007..324a491 100755 --- a/config/STID/NYCBike_InFlow.yaml +++ b/config/STID/NYCBike_InFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/NYCBike_OutFlow.yaml b/config/STID/NYCBike_OutFlow.yaml index 155baf3..c77ac79 100755 --- a/config/STID/NYCBike_OutFlow.yaml +++ b/config/STID/NYCBike_OutFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/PEMS-BAY.yaml b/config/STID/PEMS-BAY.yaml index 561102d..876a502 100755 --- a/config/STID/PEMS-BAY.yaml +++ b/config/STID/PEMS-BAY.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 300 diff --git a/config/STID/SolarEnergy.yaml b/config/STID/SolarEnergy.yaml index 0d787c9..693e371 100755 --- a/config/STID/SolarEnergy.yaml +++ b/config/STID/SolarEnergy.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/train.py b/train.py index ecf1f01..3dba2c9 100644 --- a/train.py +++ b/train.py @@ -90,9 +90,9 @@ def read_config(config_path): config = yaml.safe_load(file) # 全局配置 - device = "cuda:0" # 指定设备为cuda:0 + device = "cuda:1" # 指定设备为cuda:0 seed = 2023 # 随机种子 - epochs = 10 # 训练轮数 + epochs = 100 # 训练轮数 # 拷贝项 config["basic"]["seed"] = seed @@ -121,5 +121,5 @@ if __name__ == "__main__": all_dataset = big_dataset + mid_dataset + regular_dataset - dataset_list = regular_dataset - main(model_list, dataset_list, debug=True) + dataset_list = all_dataset + main(model_list, dataset_list, debug=False) From 2c54d81a67b74419ffaab660e637e504636a19a7 Mon Sep 17 00:00:00 2001 From: meowhe Date: Mon, 5 Jan 2026 14:18:54 +0800 Subject: [PATCH 2/2] bug fixes --- config/FPT/BJTaxi-InFlow.yaml | 4 ++-- config/FPT/BJTaxi-OutFlow.yaml | 4 ++-- config/REPST/BJTaxi-InFlow.yaml | 1 + train.py | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/config/FPT/BJTaxi-InFlow.yaml b/config/FPT/BJTaxi-InFlow.yaml index 18abb67..72b6dbc 100644 --- a/config/FPT/BJTaxi-InFlow.yaml +++ b/config/FPT/BJTaxi-InFlow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: stride: 7 train: - batch_size: 32 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/FPT/BJTaxi-OutFlow.yaml b/config/FPT/BJTaxi-OutFlow.yaml index 3e6765a..b60a145 100644 --- a/config/FPT/BJTaxi-OutFlow.yaml +++ b/config/FPT/BJTaxi-OutFlow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: stride: 7 train: - batch_size: 32 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/REPST/BJTaxi-InFlow.yaml b/config/REPST/BJTaxi-InFlow.yaml index e8a17fc..5a28595 100755 --- a/config/REPST/BJTaxi-InFlow.yaml +++ b/config/REPST/BJTaxi-InFlow.yaml @@ -27,6 +27,7 @@ model: input_dim: 1 n_heads: 1 num_nodes: 1024 + output_dim: 1 patch_len: 6 pred_len: 24 seq_len: 24 diff --git a/train.py b/train.py index 3dba2c9..f8c939e 100644 --- a/train.py +++ b/train.py @@ -110,7 +110,7 @@ def read_config(config_path): if __name__ == "__main__": # 调试用 # model_list = ["iTransformer", "PatchTST", "HI"] - model_list = ["STID"] + model_list = ["REPST"] # model_list = ["PatchTST"] air = ["AirQuality"]