diff --git a/config/FPT/BJTaxi-InFlow.yaml b/config/FPT/BJTaxi-InFlow.yaml index 18abb67..72b6dbc 100644 --- a/config/FPT/BJTaxi-InFlow.yaml +++ b/config/FPT/BJTaxi-InFlow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: stride: 7 train: - batch_size: 32 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/FPT/BJTaxi-OutFlow.yaml b/config/FPT/BJTaxi-OutFlow.yaml index 3e6765a..b60a145 100644 --- a/config/FPT/BJTaxi-OutFlow.yaml +++ b/config/FPT/BJTaxi-OutFlow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: stride: 7 train: - batch_size: 32 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/REPST/BJTaxi-InFlow.yaml b/config/REPST/BJTaxi-InFlow.yaml index e8a17fc..5a28595 100755 --- a/config/REPST/BJTaxi-InFlow.yaml +++ b/config/REPST/BJTaxi-InFlow.yaml @@ -27,6 +27,7 @@ model: input_dim: 1 n_heads: 1 num_nodes: 1024 + output_dim: 1 patch_len: 6 pred_len: 24 seq_len: 24 diff --git a/train.py b/train.py index 3dba2c9..f8c939e 100644 --- a/train.py +++ b/train.py @@ -110,7 +110,7 @@ def read_config(config_path): if __name__ == "__main__": # 调试用 # model_list = ["iTransformer", "PatchTST", "HI"] - model_list = ["STID"] + model_list = ["REPST"] # model_list = ["PatchTST"] air = ["AirQuality"]