diff --git a/config/STID/AirQuality.yaml b/config/STID/AirQuality.yaml index b480b4f..46098c6 100755 --- a/config/STID/AirQuality.yaml +++ b/config/STID/AirQuality.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/BJTaxi-InFlow.yaml b/config/STID/BJTaxi-InFlow.yaml index 59e9501..b3e7b87 100644 --- a/config/STID/BJTaxi-InFlow.yaml +++ b/config/STID/BJTaxi-InFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/BJTaxi-OutFlow.yaml b/config/STID/BJTaxi-OutFlow.yaml index e2fdf43..822f74c 100644 --- a/config/STID/BJTaxi-OutFlow.yaml +++ b/config/STID/BJTaxi-OutFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/BJTaxi_InFlow.yaml b/config/STID/BJTaxi_InFlow.yaml index d50ba22..d12df1b 100755 --- a/config/STID/BJTaxi_InFlow.yaml +++ b/config/STID/BJTaxi_InFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/BJTaxi_OutFlow.yaml b/config/STID/BJTaxi_OutFlow.yaml index e2fdf43..822f74c 100755 --- a/config/STID/BJTaxi_OutFlow.yaml +++ b/config/STID/BJTaxi_OutFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/METR-LA.yaml b/config/STID/METR-LA.yaml index 7ceb4f0..7ab5199 100755 --- a/config/STID/METR-LA.yaml +++ b/config/STID/METR-LA.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 300 diff --git a/config/STID/NYCBike-InFlow.yaml b/config/STID/NYCBike-InFlow.yaml index e509007..324a491 100644 --- a/config/STID/NYCBike-InFlow.yaml +++ b/config/STID/NYCBike-InFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/NYCBike-OutFlow.yaml b/config/STID/NYCBike-OutFlow.yaml index 155baf3..c77ac79 100644 --- a/config/STID/NYCBike-OutFlow.yaml +++ b/config/STID/NYCBike-OutFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/NYCBike_InFlow.yaml b/config/STID/NYCBike_InFlow.yaml index e509007..324a491 100755 --- a/config/STID/NYCBike_InFlow.yaml +++ b/config/STID/NYCBike_InFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/NYCBike_OutFlow.yaml b/config/STID/NYCBike_OutFlow.yaml index 155baf3..c77ac79 100755 --- a/config/STID/NYCBike_OutFlow.yaml +++ b/config/STID/NYCBike_OutFlow.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/config/STID/PEMS-BAY.yaml b/config/STID/PEMS-BAY.yaml index 561102d..876a502 100755 --- a/config/STID/PEMS-BAY.yaml +++ b/config/STID/PEMS-BAY.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 300 diff --git a/config/STID/SolarEnergy.yaml b/config/STID/SolarEnergy.yaml index 0d787c9..693e371 100755 --- a/config/STID/SolarEnergy.yaml +++ b/config/STID/SolarEnergy.yaml @@ -40,7 +40,7 @@ model: train: batch_size: 64 - debug: true + debug: false early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/train.py b/train.py index ecf1f01..3dba2c9 100644 --- a/train.py +++ b/train.py @@ -90,9 +90,9 @@ def read_config(config_path): config = yaml.safe_load(file) # 全局配置 - device = "cuda:0" # 指定设备为cuda:0 + device = "cuda:1" # 指定设备为cuda:0 seed = 2023 # 随机种子 - epochs = 10 # 训练轮数 + epochs = 100 # 训练轮数 # 拷贝项 config["basic"]["seed"] = seed @@ -121,5 +121,5 @@ if __name__ == "__main__": all_dataset = big_dataset + mid_dataset + regular_dataset - dataset_list = regular_dataset - main(model_list, dataset_list, debug=True) + dataset_list = all_dataset + main(model_list, dataset_list, debug=False)