diff --git a/config/HI/BJTaxi-Outflow.yaml b/config/HI/BJTaxi-OutFlow.yaml similarity index 100% rename from config/HI/BJTaxi-Outflow.yaml rename to config/HI/BJTaxi-OutFlow.yaml diff --git a/config/HI/SolarEnergy.yaml b/config/HI/SolarEnergy.yaml index 1558d3a..aa07cf6 100644 --- a/config/HI/SolarEnergy.yaml +++ b/config/HI/SolarEnergy.yaml @@ -6,11 +6,11 @@ basic: seed: 2023 data: - batch_size: 512 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 - input_dim: 6 + input_dim: 1 lag: 24 normalizer: std num_nodes: 137 @@ -25,7 +25,7 @@ model: train: - batch_size: 512 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/MTGNN/BJTaxi-Inflow.yaml b/config/MTGNN/BJTaxi-InFlow.yaml similarity index 100% rename from config/MTGNN/BJTaxi-Inflow.yaml rename to config/MTGNN/BJTaxi-InFlow.yaml diff --git a/config/MTGNN/BJTaxi-Outflow.yaml b/config/MTGNN/BJTaxi-OutFlow.yaml similarity index 100% rename from config/MTGNN/BJTaxi-Outflow.yaml rename to config/MTGNN/BJTaxi-OutFlow.yaml diff --git a/config/PatchTST/BJTaxi-Inflow.yaml b/config/PatchTST/BJTaxi-InFlow.yaml similarity index 100% rename from config/PatchTST/BJTaxi-Inflow.yaml rename to config/PatchTST/BJTaxi-InFlow.yaml diff --git a/config/PatchTST/BJTaxi-Outflow.yaml b/config/PatchTST/BJTaxi-OutFlow.yaml similarity index 100% rename from config/PatchTST/BJTaxi-Outflow.yaml rename to config/PatchTST/BJTaxi-OutFlow.yaml diff --git a/config/PatchTST/NYCBike-Inflow.yaml b/config/PatchTST/NYCBike-InFlow.yaml similarity index 100% rename from config/PatchTST/NYCBike-Inflow.yaml rename to config/PatchTST/NYCBike-InFlow.yaml diff --git a/config/PatchTST/NYCBike-Outflow.yaml b/config/PatchTST/NYCBike-OutFlow.yaml similarity index 100% rename from config/PatchTST/NYCBike-Outflow.yaml rename to config/PatchTST/NYCBike-OutFlow.yaml diff --git a/config/PatchTST/SolarEnergy.yaml b/config/PatchTST/SolarEnergy.yaml index d31a458..b6ca055 100644 --- a/config/PatchTST/SolarEnergy.yaml +++ b/config/PatchTST/SolarEnergy.yaml @@ -10,7 +10,7 @@ data: column_wise: false days_per_week: 7 horizon: 24 - input_dim: 6 + input_dim: 1 lag: 24 normalizer: std num_nodes: 137 diff --git a/config/REPST/BJTaxi-InFlow.yaml b/config/REPST/BJTaxi-InFlow.yaml old mode 100644 new mode 100755 index 3191eba..e8a17fc --- a/config/REPST/BJTaxi-InFlow.yaml +++ b/config/REPST/BJTaxi-InFlow.yaml @@ -1,6 +1,6 @@ basic: dataset: BJTaxi-InFlow - device: cuda:1 + device: cuda:0 mode: train model: REPST seed: 2023 @@ -27,7 +27,6 @@ model: input_dim: 1 n_heads: 1 num_nodes: 1024 - output_dim: 1 patch_len: 6 pred_len: 24 seq_len: 24 @@ -41,7 +40,7 @@ train: early_stop_patience: 15 epochs: 100 grad_norm: false - log_step: 1000 + log_step: 100 loss_func: mae lr_decay: true lr_decay_rate: 0.3 diff --git a/config/REPST/BJTaxi-Inflow.yaml b/config/REPST/BJTaxi-Inflow.yaml deleted file mode 100755 index e8a17fc..0000000 --- a/config/REPST/BJTaxi-Inflow.yaml +++ /dev/null @@ -1,55 +0,0 @@ -basic: - dataset: BJTaxi-InFlow - device: cuda:0 - mode: train - model: REPST - seed: 2023 - -data: - batch_size: 16 - column_wise: false - days_per_week: 7 - horizon: 24 - input_dim: 1 - lag: 24 - normalizer: std - num_nodes: 1024 - steps_per_day: 48 - test_ratio: 0.2 - val_ratio: 0.2 - -model: - d_ff: 128 - d_model: 64 - dropout: 0.2 - gpt_layers: 9 - gpt_path: ./GPT-2 - input_dim: 1 - n_heads: 1 - num_nodes: 1024 - patch_len: 6 - pred_len: 24 - seq_len: 24 - stride: 7 - word_num: 1000 - -train: - batch_size: 16 - debug: false - early_stop: true - early_stop_patience: 15 - epochs: 100 - grad_norm: false - log_step: 100 - loss_func: mae - lr_decay: true - lr_decay_rate: 0.3 - lr_decay_step: 5,20,40,70 - lr_init: 0.003 - mae_thresh: None - mape_thresh: 0.001 - max_grad_norm: 5 - output_dim: 1 - plot: false - real_value: true - weight_decay: 0 diff --git a/config/STID/BJTaxi_Inflow.yaml b/config/STID/BJTaxi_InFlow.yaml similarity index 100% rename from config/STID/BJTaxi_Inflow.yaml rename to config/STID/BJTaxi_InFlow.yaml diff --git a/config/STID/BJTaxi_Outflow.yaml b/config/STID/BJTaxi_OutFlow.yaml similarity index 100% rename from config/STID/BJTaxi_Outflow.yaml rename to config/STID/BJTaxi_OutFlow.yaml diff --git a/config/STID/NYCBike_Inflow.yaml b/config/STID/NYCBike_InFlow.yaml similarity index 100% rename from config/STID/NYCBike_Inflow.yaml rename to config/STID/NYCBike_InFlow.yaml diff --git a/config/STID/NYCBike_Outflow.yaml b/config/STID/NYCBike_OutFlow.yaml similarity index 100% rename from config/STID/NYCBike_Outflow.yaml rename to config/STID/NYCBike_OutFlow.yaml diff --git a/train.py b/train.py index b0b5af1..fcaaa6a 100644 --- a/train.py +++ b/train.py @@ -89,10 +89,10 @@ def main(model, data, debug=False): if __name__ == "__main__": # 调试用 - # model_list = ["iTransformer", "PatchTST", "HI"] + model_list = ["iTransformer", "PatchTST", "HI"] # model_list = ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"] - model_list = ["iTransformer"] - dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] + # model_list = ["iTransformer"] + # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] # dataset_list = ["AirQuality"] - # dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"] + dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"] main(model_list, dataset_list, debug = True) \ No newline at end of file