统一REPST:AirQuality配置

This commit is contained in:
czzhangheng 2025-12-16 17:38:37 +08:00
parent b38e4a5da2
commit b97111f5ea
2 changed files with 3 additions and 3 deletions

View File

@ -50,7 +50,7 @@ train:
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001
max_grad_norm: 5 max_grad_norm: 5
output_dim: 3 output_dim: 6
plot: false plot: false
real_value: true real_value: true
weight_decay: 0 weight_decay: 0

View File

@ -89,8 +89,8 @@ def main(model, data, debug=False):
if __name__ == "__main__": if __name__ == "__main__":
# 调试用 # 调试用
model_list = ["iTransformer", "PatchTST", "HI"] # model_list = ["iTransformer", "PatchTST", "HI"]
# model_list = ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"] model_list = ["ASTRA_v3", "ASTRA_v2", "ASTRA", "REPST", "STAEFormer", "MTGNN", "iTransformer", "PatchTST", "HI"]
# model_list = ["MTGNN"] # model_list = ["MTGNN"]
# dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"]
# dataset_list = ["AirQuality"] # dataset_list = ["AirQuality"]