diff --git a/config/FPT/BJTaxi-InFlow.yaml b/config/FPT/BJTaxi-InFlow.yaml index 18abb67..72b6dbc 100644 --- a/config/FPT/BJTaxi-InFlow.yaml +++ b/config/FPT/BJTaxi-InFlow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: stride: 7 train: - batch_size: 32 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/FPT/BJTaxi-OutFlow.yaml b/config/FPT/BJTaxi-OutFlow.yaml index 3e6765a..b60a145 100644 --- a/config/FPT/BJTaxi-OutFlow.yaml +++ b/config/FPT/BJTaxi-OutFlow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: stride: 7 train: - batch_size: 32 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/REPST/BJTaxi-InFlow.yaml b/config/REPST/BJTaxi-InFlow.yaml index 8bdd348..5ded14f 100755 --- a/config/REPST/BJTaxi-InFlow.yaml +++ b/config/REPST/BJTaxi-InFlow.yaml @@ -28,6 +28,7 @@ model: output_dim: 1 n_heads: 1 num_nodes: 1024 + output_dim: 1 patch_len: 6 pred_len: 24 seq_len: 24 diff --git a/config/STID/AirQuality.yaml b/config/STID/AirQuality.yaml index b480b4f..46098c6 100755 --- a/config/STID/AirQuality.yaml +++ b/config/STID/AirQuality.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/BJTaxi-InFlow.yaml b/config/STID/BJTaxi-InFlow.yaml index 59e9501..b3e7b87 100644 --- a/config/STID/BJTaxi-InFlow.yaml +++ b/config/STID/BJTaxi-InFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/BJTaxi-OutFlow.yaml b/config/STID/BJTaxi-OutFlow.yaml index e2fdf43..822f74c 100644 --- a/config/STID/BJTaxi-OutFlow.yaml +++ b/config/STID/BJTaxi-OutFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/BJTaxi_InFlow.yaml b/config/STID/BJTaxi_InFlow.yaml index d50ba22..d12df1b 100755 --- a/config/STID/BJTaxi_InFlow.yaml +++ b/config/STID/BJTaxi_InFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/BJTaxi_OutFlow.yaml b/config/STID/BJTaxi_OutFlow.yaml index e2fdf43..822f74c 100755 --- a/config/STID/BJTaxi_OutFlow.yaml +++ b/config/STID/BJTaxi_OutFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/METR-LA.yaml b/config/STID/METR-LA.yaml index 7ceb4f0..7ab5199 100755 --- a/config/STID/METR-LA.yaml +++ b/config/STID/METR-LA.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 300 diff --git a/config/STID/NYCBike-InFlow.yaml b/config/STID/NYCBike-InFlow.yaml index e509007..324a491 100644 --- a/config/STID/NYCBike-InFlow.yaml +++ b/config/STID/NYCBike-InFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/NYCBike-OutFlow.yaml b/config/STID/NYCBike-OutFlow.yaml index 155baf3..c77ac79 100644 --- a/config/STID/NYCBike-OutFlow.yaml +++ b/config/STID/NYCBike-OutFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/NYCBike_InFlow.yaml b/config/STID/NYCBike_InFlow.yaml index e509007..324a491 100755 --- a/config/STID/NYCBike_InFlow.yaml +++ b/config/STID/NYCBike_InFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/NYCBike_OutFlow.yaml b/config/STID/NYCBike_OutFlow.yaml index 155baf3..c77ac79 100755 --- a/config/STID/NYCBike_OutFlow.yaml +++ b/config/STID/NYCBike_OutFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/PEMS-BAY.yaml b/config/STID/PEMS-BAY.yaml index 561102d..876a502 100755 --- a/config/STID/PEMS-BAY.yaml +++ b/config/STID/PEMS-BAY.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 300 diff --git a/config/STID/SolarEnergy.yaml b/config/STID/SolarEnergy.yaml index 0d787c9..693e371 100755 --- a/config/STID/SolarEnergy.yaml +++ b/config/STID/SolarEnergy.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/train.py b/train.py index b489cde..f8c939e 100644 --- a/train.py +++ b/train.py @@ -90,9 +90,9 @@ def read_config(config_path): config = yaml.safe_load(file) # 全局配置 - device = "cuda:0" # 指定设备为cuda:0 + device = "cuda:1" # 指定设备为cuda:0 seed = 2023 # 随机种子 - epochs = 10 # 训练轮数 + epochs = 100 # 训练轮数 # 拷贝项 config["basic"]["seed"] = seed @@ -121,5 +121,5 @@ if __name__ == "__main__": all_dataset = big_dataset + mid_dataset + regular_dataset - dataset_list = regular_dataset - main(model_list, dataset_list, debug=True) + dataset_list = all_dataset + main(model_list, dataset_list, debug=False)