diff --git a/config/Informer/PEMS-BAY.yaml b/config/Informer/PEMS-BAY.yaml index 3ab5584..276fb4a 100644 --- a/config/Informer/PEMS-BAY.yaml +++ b/config/Informer/PEMS-BAY.yaml @@ -22,7 +22,7 @@ data: model: activation: gelu seq_len: 24 - label_len: 12 + label_len: 24 pred_len: 24 d_model: 128 d_ff: 2048 @@ -59,7 +59,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 6 + output_dim: 1 plot: false pred_len: 24 real_value: true diff --git a/config/Informer/SolarEnergy.yaml b/config/Informer/SolarEnergy.yaml index a1d5ea2..570c595 100644 --- a/config/Informer/SolarEnergy.yaml +++ b/config/Informer/SolarEnergy.yaml @@ -6,19 +6,18 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 - input_dim: 6 + input_dim: 1 lag: 24 - label_len: 24 normalizer: std num_nodes: 137 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 - + model: activation: gelu seq_len: 24 @@ -43,7 +42,7 @@ model: train: - batch_size: 16 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/train.py b/train.py index 19ec820..97f9bc8 100644 --- a/train.py +++ b/train.py @@ -12,7 +12,7 @@ def read_config(config_path): config = yaml.safe_load(file) # 全局配置 - device = "cuda:0" # 指定设备为cuda:0 + device = "cpu" # 指定设备为cuda:0 seed = 2023 # 随机种子 epochs = 1 # 训练轮数 @@ -102,10 +102,11 @@ def main(model_list, data, debug=False): if __name__ == "__main__": # 调试用 # model_list = ["iTransformer", "PatchTST", "HI"] - model_list = ["Informer"] + model_list = ["iTransformer", "Informer"] # model_list = ["PatchTST"] # dataset_list = ["AirQuality"] + dataset_list = ["SolarEnergy"] # dataset_list = ["BJTaxi-InFlow", "BJTaxi-OutFlow"] - dataset_list = ["AirQuality", "PEMS-BAY", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"] - # dataset_list = ["METR-LA"] + # dataset_list = ["AirQuality", "PEMS-BAY", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"] + # dataset_list = ["BJTaxi-OutFlow"] main(model_list, dataset_list, debug=True)