From b97111f5ea1db1e2c14f3cefaef438b6c2868815 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Tue, 16 Dec 2025 17:38:37 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BB=9F=E4=B8=80REPST:AirQuality=E9=85=8D?= =?UTF-8?q?=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/REPST/AirQuality.yaml | 2 +- train.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/config/REPST/AirQuality.yaml b/config/REPST/AirQuality.yaml index c035a44..ee690f1 100755 --- a/config/REPST/AirQuality.yaml +++ b/config/REPST/AirQuality.yaml @@ -50,7 +50,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 3 + output_dim: 6 plot: false real_value: true weight_decay: 0 diff --git a/train.py b/train.py index b5b42b5..da6c058 100644 --- a/train.py +++ b/train.py @@ -89,8 +89,8 @@ def main(model, data, debug=False): if __name__ == "__main__": # 调试用 - model_list = ["iTransformer", "PatchTST", "HI"] - # model_list = ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"] + # model_list = ["iTransformer", "PatchTST", "HI"] + model_list = ["ASTRA_v3", "ASTRA_v2", "ASTRA", "REPST", "STAEFormer", "MTGNN", "iTransformer", "PatchTST", "HI"] # model_list = ["MTGNN"] # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] # dataset_list = ["AirQuality"]