diff --git a/config/REPST/AirQuality.yaml b/config/REPST/AirQuality.yaml index c035a44..ee690f1 100755 --- a/config/REPST/AirQuality.yaml +++ b/config/REPST/AirQuality.yaml @@ -50,7 +50,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 3 + output_dim: 6 plot: false real_value: true weight_decay: 0 diff --git a/train.py b/train.py index b5b42b5..da6c058 100644 --- a/train.py +++ b/train.py @@ -89,8 +89,8 @@ def main(model, data, debug=False): if __name__ == "__main__": # 调试用 - model_list = ["iTransformer", "PatchTST", "HI"] - # model_list = ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"] + # model_list = ["iTransformer", "PatchTST", "HI"] + model_list = ["ASTRA_v3", "ASTRA_v2", "ASTRA", "REPST", "STAEFormer", "MTGNN", "iTransformer", "PatchTST", "HI"] # model_list = ["MTGNN"] # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] # dataset_list = ["AirQuality"]