From 521d72587f54594dc51345030ba35072c44389ac Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Tue, 6 Jan 2026 10:21:22 +0800 Subject: [PATCH] fix repst config --- config/REPST/BJTaxi-InFlow.yaml | 1 + config/REPST/BJTaxi-outflow.yaml | 55 ------------------------------- config/REPST/METR-LA.yaml | 1 + config/REPST/NYCBike-inflow.yaml | 55 ------------------------------- config/REPST/NYCBike-outflow.yaml | 55 ------------------------------- config/REPST/PEMS-BAY.yaml | 1 + config/REPST/SolarEnergy.yaml | 1 + train.py | 2 +- 8 files changed, 5 insertions(+), 166 deletions(-) delete mode 100755 config/REPST/BJTaxi-outflow.yaml delete mode 100755 config/REPST/NYCBike-inflow.yaml delete mode 100755 config/REPST/NYCBike-outflow.yaml diff --git a/config/REPST/BJTaxi-InFlow.yaml b/config/REPST/BJTaxi-InFlow.yaml index e8a17fc..8bdd348 100755 --- a/config/REPST/BJTaxi-InFlow.yaml +++ b/config/REPST/BJTaxi-InFlow.yaml @@ -25,6 +25,7 @@ model: gpt_layers: 9 gpt_path: ./GPT-2 input_dim: 1 + output_dim: 1 n_heads: 1 num_nodes: 1024 patch_len: 6 diff --git a/config/REPST/BJTaxi-outflow.yaml b/config/REPST/BJTaxi-outflow.yaml deleted file mode 100755 index bd5fbb8..0000000 --- a/config/REPST/BJTaxi-outflow.yaml +++ /dev/null @@ -1,55 +0,0 @@ -basic: - dataset: BJTaxi-OutFlow - device: cuda:0 - mode: train - model: REPST - seed: 2023 - -data: - batch_size: 16 - column_wise: false - days_per_week: 7 - horizon: 24 - input_dim: 1 - lag: 24 - normalizer: std - num_nodes: 1024 - steps_per_day: 48 - test_ratio: 0.2 - val_ratio: 0.2 - -model: - d_ff: 128 - d_model: 64 - dropout: 0.2 - gpt_layers: 9 - gpt_path: ./GPT-2 - input_dim: 1 - n_heads: 1 - num_nodes: 1024 - patch_len: 6 - pred_len: 24 - seq_len: 24 - stride: 7 - word_num: 1000 - -train: - batch_size: 16 - debug: false - early_stop: true - early_stop_patience: 15 - epochs: 100 - grad_norm: false - log_step: 100 - loss_func: mae - lr_decay: true - lr_decay_rate: 0.3 - lr_decay_step: 5,20,40,70 - lr_init: 0.003 - mae_thresh: None - mape_thresh: 0.001 - max_grad_norm: 5 - output_dim: 1 - plot: false - real_value: true - weight_decay: 0 diff --git a/config/REPST/METR-LA.yaml b/config/REPST/METR-LA.yaml index d2fbdf9..db6a6d0 100755 --- a/config/REPST/METR-LA.yaml +++ b/config/REPST/METR-LA.yaml @@ -25,6 +25,7 @@ model: gpt_layers: 9 gpt_path: ./GPT-2 input_dim: 1 + output_dim: 1 n_heads: 1 num_nodes: 207 patch_len: 6 diff --git a/config/REPST/NYCBike-inflow.yaml b/config/REPST/NYCBike-inflow.yaml deleted file mode 100755 index e59e3cf..0000000 --- a/config/REPST/NYCBike-inflow.yaml +++ /dev/null @@ -1,55 +0,0 @@ -basic: - dataset: NYCBike-InFlow - device: cuda:0 - mode: train - model: REPST - seed: 2023 - -data: - batch_size: 16 - column_wise: false - days_per_week: 7 - horizon: 24 - input_dim: 1 - lag: 24 - normalizer: std - num_nodes: 128 - steps_per_day: 48 - test_ratio: 0.2 - val_ratio: 0.2 - -model: - d_ff: 128 - d_model: 64 - dropout: 0.2 - gpt_layers: 9 - gpt_path: ./GPT-2 - input_dim: 1 - n_heads: 1 - num_nodes: 128 - patch_len: 6 - pred_len: 24 - seq_len: 24 - stride: 7 - word_num: 1000 - -train: - batch_size: 16 - debug: false - early_stop: true - early_stop_patience: 15 - epochs: 100 - grad_norm: false - log_step: 100 - loss_func: mae - lr_decay: true - lr_decay_rate: 0.3 - lr_decay_step: 5,20,40,70 - lr_init: 0.003 - mae_thresh: None - mape_thresh: 0.001 - max_grad_norm: 5 - output_dim: 1 - plot: false - real_value: true - weight_decay: 0 diff --git a/config/REPST/NYCBike-outflow.yaml b/config/REPST/NYCBike-outflow.yaml deleted file mode 100755 index 59a4389..0000000 --- a/config/REPST/NYCBike-outflow.yaml +++ /dev/null @@ -1,55 +0,0 @@ -basic: - dataset: NYCBike-OutFlow - device: cuda:0 - mode: train - model: REPST - seed: 2023 - -data: - batch_size: 16 - column_wise: false - days_per_week: 7 - horizon: 24 - input_dim: 1 - lag: 24 - normalizer: std - num_nodes: 128 - steps_per_day: 48 - test_ratio: 0.2 - val_ratio: 0.2 - -model: - d_ff: 128 - d_model: 64 - dropout: 0.2 - gpt_layers: 9 - gpt_path: ./GPT-2 - input_dim: 1 - n_heads: 1 - num_nodes: 128 - patch_len: 6 - pred_len: 24 - seq_len: 24 - stride: 7 - word_num: 1000 - -train: - batch_size: 16 - debug: false - early_stop: true - early_stop_patience: 15 - epochs: 100 - grad_norm: false - log_step: 100 - loss_func: mae - lr_decay: true - lr_decay_rate: 0.3 - lr_decay_step: 5,20,40,70 - lr_init: 0.003 - mae_thresh: None - mape_thresh: 0.001 - max_grad_norm: 5 - output_dim: 1 - plot: false - real_value: true - weight_decay: 0 diff --git a/config/REPST/PEMS-BAY.yaml b/config/REPST/PEMS-BAY.yaml index e1dcf0b..96e1a7c 100755 --- a/config/REPST/PEMS-BAY.yaml +++ b/config/REPST/PEMS-BAY.yaml @@ -25,6 +25,7 @@ model: gpt_layers: 9 gpt_path: ./GPT-2 input_dim: 1 + output_dim: 1 n_heads: 1 num_nodes: 325 patch_len: 6 diff --git a/config/REPST/SolarEnergy.yaml b/config/REPST/SolarEnergy.yaml index a96e58a..ef42011 100755 --- a/config/REPST/SolarEnergy.yaml +++ b/config/REPST/SolarEnergy.yaml @@ -25,6 +25,7 @@ model: gpt_layers: 9 gpt_path: ./GPT-2 input_dim: 1 + output_dim: 1 n_heads: 1 num_nodes: 137 patch_len: 6 diff --git a/train.py b/train.py index ecf1f01..b489cde 100644 --- a/train.py +++ b/train.py @@ -110,7 +110,7 @@ def read_config(config_path): if __name__ == "__main__": # 调试用 # model_list = ["iTransformer", "PatchTST", "HI"] - model_list = ["STID"] + model_list = ["REPST"] # model_list = ["PatchTST"] air = ["AirQuality"]