From 9c50c30918cc7778f98e038657049c7a7f9caa28 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Mon, 24 Nov 2025 21:50:24 +0800 Subject: [PATCH] AEPSA v0.1 --- .vscode/launch.json | 110 ++++++-- config/AEPSA/AirQuality.yaml | 58 ++++ .../AEPSA/{PEMSD8.yaml => BJTaxi-Inflow.yaml} | 18 +- .../PEMSD8.yaml => AEPSA/BJTaxi-outflow.yaml} | 20 +- config/AEPSA/METR-LA.yaml | 59 ++++ config/AEPSA/NYCBike-inflow.yaml | 58 ++++ config/AEPSA/NYCBike-outflow.yaml | 58 ++++ config/AEPSA/PEMS-BAY.yaml | 4 +- config/AEPSA/SolarEnergy.yaml | 59 ++++ config/REPST/AirQuality.yaml | 1 + config/REPST/BJTaxi-Inflow.yaml | 1 + config/REPST/BJTaxi-outflow.yaml | 1 + .../REPST/BeijingAirQuality(Deprecated).yaml | 1 + config/REPST/METR-LA.yaml | 1 + config/REPST/NYCBike-inflow.yaml | 1 + config/REPST/NYCBike-outflow.yaml | 1 + config/REPST/PEMS-BAY.yaml | 1 + config/REPST/PEMS-BAY_paper.yaml | 1 + config/REPST/SolarEnergy.yaml | 5 +- model/AEPSA/aepsa.py | 251 ++++++++++++++++++ model/AEPSA/repst.py | 103 ------- model/model_selector.py | 2 +- utils/Download_data.py | 1 + 23 files changed, 664 insertions(+), 151 deletions(-) create mode 100644 config/AEPSA/AirQuality.yaml rename config/AEPSA/{PEMSD8.yaml => BJTaxi-Inflow.yaml} (82%) mode change 100755 => 100644 rename config/{REPST/PEMSD8.yaml => AEPSA/BJTaxi-outflow.yaml} (80%) mode change 100755 => 100644 create mode 100644 config/AEPSA/METR-LA.yaml create mode 100644 config/AEPSA/NYCBike-inflow.yaml create mode 100644 config/AEPSA/NYCBike-outflow.yaml create mode 100644 config/AEPSA/SolarEnergy.yaml create mode 100644 model/AEPSA/aepsa.py delete mode 100644 model/AEPSA/repst.py diff --git a/.vscode/launch.json b/.vscode/launch.json index c9f73d8..f8f45f8 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -3,9 +3,11 @@ // 悬停以查看现有属性的描述。 // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", + "configurations": [ + // STID 模型组 { - "name": "STID_PEMS-BAY", + "name": "STID: PEMS-BAY", "type": "debugpy", "request": "launch", "program": "run.py", @@ -13,7 +15,7 @@ "args": "--config ./config/STID/PEMS-BAY.yaml" }, { - "name": "STID_PEMSD4", + "name": "STID: PEMSD4", "type": "debugpy", "request": "launch", "program": "run.py", @@ -21,15 +23,7 @@ "args": "--config ./config/STID/PEMSD4.yaml" }, { - "name": "REPST", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/PEMSD8.yaml" - }, - { - "name": "STID-BJTaxi-InFlow", + "name": "STID: BJTaxi-InFlow", "type": "debugpy", "request": "launch", "program": "run.py", @@ -37,7 +31,7 @@ "args": "--config ./config/STID/BJTaxi_Inflow.yaml" }, { - "name": "STID-BJTaxi-OutFlow", + "name": "STID: BJTaxi-OutFlow", "type": "debugpy", "request": "launch", "program": "run.py", @@ -45,7 +39,7 @@ "args": "--config ./config/STID/BJTaxi_Outflow.yaml" }, { - "name": "STID-NYCBike-InFlow", + "name": "STID: NYCBike-InFlow", "type": "debugpy", "request": "launch", "program": "run.py", @@ -53,7 +47,7 @@ "args": "--config ./config/STID/NYCBike_Inflow.yaml" }, { - "name": "STID-NYCBike-OutFlow", + "name": "STID: NYCBike-OutFlow", "type": "debugpy", "request": "launch", "program": "run.py", @@ -61,15 +55,25 @@ "args": "--config ./config/STID/NYCBike_Outflow.yaml" }, { - "name": "STID-SolarEnergy", + "name": "STID: SolarEnergy", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", "args": "--config ./config/STID/SolarEnergy.yaml" }, + + // REPST 模型组 { - "name": "REPST-BJTaxi-InFlow", + "name": "REPST: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/REPST/PEMSD8.yaml" + }, + { + "name": "REPST: BJTaxi-InFlow", "type": "debugpy", "request": "launch", "program": "run.py", @@ -77,7 +81,7 @@ "args": "--config ./config/REPST/BJTaxi-Inflow.yaml" }, { - "name": "REPST-NYCBike-outflow", + "name": "REPST: NYCBike-outflow", "type": "debugpy", "request": "launch", "program": "run.py", @@ -85,7 +89,7 @@ "args": "--config ./config/REPST/NYCBike-outflow.yaml" }, { - "name": "REPST-NYCBike-inflow", + "name": "REPST: NYCBike-inflow", "type": "debugpy", "request": "launch", "program": "run.py", @@ -93,7 +97,7 @@ "args": "--config ./config/REPST/NYCBike-inflow.yaml" }, { - "name": "REPST-PEMSBAY", + "name": "REPST: PEMS-BAY", "type": "debugpy", "request": "launch", "program": "run.py", @@ -101,7 +105,7 @@ "args": "--config ./config/REPST/PEMS-BAY.yaml" }, { - "name": "REPST-METR", + "name": "REPST: METR-LA", "type": "debugpy", "request": "launch", "program": "run.py", @@ -109,7 +113,7 @@ "args": "--config ./config/REPST/METR-LA.yaml" }, { - "name": "REPST-Solar", + "name": "REPST: SolarEnergy", "type": "debugpy", "request": "launch", "program": "run.py", @@ -117,7 +121,7 @@ "args": "--config ./config/REPST/SolarEnergy.yaml" }, { - "name": "BeijingAirQuality", + "name": "REPST: BeijingAirQuality", "type": "debugpy", "request": "launch", "program": "run.py", @@ -125,20 +129,78 @@ "args": "--config ./config/REPST/BeijingAirQuality.yaml" }, { - "name": "AirQuality", + "name": "REPST: AirQuality", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", "args": "--config ./config/REPST/AirQuality.yaml" }, + + // AEPSA 模型组 { - "name": "AEPSA-PEMSBAY", + "name": "AEPSA: PEMS-BAY", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", "args": "--config ./config/AEPSA/PEMS-BAY.yaml" + }, + { + "name": "AEPSA: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AEPSA/METR-LA.yaml" + }, + { + "name": "AEPSA: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AEPSA/AirQuality.yaml" + }, + { + "name": "AEPSA: BJTaxi-Inflow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AEPSA/BJTaxi-Inflow.yaml" + }, + { + "name": "AEPSA: BJTaxi-outflow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AEPSA/BJTaxi-outflow.yaml" + }, + { + "name": "AEPSA: NYCBike-inflow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AEPSA/NYCBike-inflow.yaml" + }, + { + "name": "AEPSA: NYCBike-outflow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AEPSA/NYCBike-outflow.yaml" + }, + { + "name": "AEPSA: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AEPSA/SolarEnergy.yaml" } ] } \ No newline at end of file diff --git a/config/AEPSA/AirQuality.yaml b/config/AEPSA/AirQuality.yaml new file mode 100644 index 0000000..652a544 --- /dev/null +++ b/config/AEPSA/AirQuality.yaml @@ -0,0 +1,58 @@ +basic: + dataset: "AirQuality" + mode : "train" + device : "cuda:0" + model: "AEPSA" + seed: 2023 + +data: + add_day_in_week: true + add_time_in_day: true + column_wise: false + days_per_week: 7 + default_graph: true + horizon: 24 + lag: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 24 + test_ratio: 0.2 + tod: false + val_ratio: 0.2 + sample: 1 + input_dim: 6 + batch_size: 16 + +model: + pred_len: 24 + seq_len: 24 + patch_len: 6 + stride: 7 + dropout: 0.2 + gpt_layers: 9 + d_ff: 128 + gpt_path: ./GPT-2 + d_model: 64 + n_heads: 1 + input_dim: 6 + word_num: 1000 + +train: + batch_size: 16 + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + lr_init: 0.003 + max_grad_norm: 5 + weight_decay: 0 + debug: false + output_dim: 6 + log_step: 100 + plot: false + mae_thresh: None + mape_thresh: 0.001 diff --git a/config/AEPSA/PEMSD8.yaml b/config/AEPSA/BJTaxi-Inflow.yaml old mode 100755 new mode 100644 similarity index 82% rename from config/AEPSA/PEMSD8.yaml rename to config/AEPSA/BJTaxi-Inflow.yaml index 363d73a..771119b --- a/config/AEPSA/PEMSD8.yaml +++ b/config/AEPSA/BJTaxi-Inflow.yaml @@ -1,8 +1,9 @@ basic: - dataset: "PEMSD8" + dataset: "BJTaxi-Inflow" mode : "train" device : "cuda:0" model: "AEPSA" + seed: 2023 data: add_day_in_week: true @@ -13,14 +14,14 @@ data: horizon: 12 lag: 12 normalizer: std - num_nodes: 170 - steps_per_day: 288 + num_nodes: 142 + steps_per_day: 48 test_ratio: 0.2 tod: false val_ratio: 0.2 sample: 1 input_dim: 1 - batch_size: 64 + batch_size: 32 model: pred_len: 12 @@ -33,9 +34,11 @@ model: gpt_path: ./GPT-2 d_model: 64 n_heads: 1 + input_dim: 1 + word_num: 1000 train: - batch_size: 64 + batch_size: 32 early_stop: true early_stop_patience: 15 epochs: 100 @@ -46,13 +49,10 @@ train: lr_decay_step: "5,20,40,70" lr_init: 0.003 max_grad_norm: 5 - real_value: true - seed: 12 weight_decay: 0 debug: false output_dim: 1 - log_step: 2000 + log_step: 100 plot: false mae_thresh: None mape_thresh: 0.001 - diff --git a/config/REPST/PEMSD8.yaml b/config/AEPSA/BJTaxi-outflow.yaml old mode 100755 new mode 100644 similarity index 80% rename from config/REPST/PEMSD8.yaml rename to config/AEPSA/BJTaxi-outflow.yaml index 3663a32..936b1a3 --- a/config/REPST/PEMSD8.yaml +++ b/config/AEPSA/BJTaxi-outflow.yaml @@ -1,8 +1,9 @@ basic: - dataset: "PEMSD8" + dataset: "BJTaxi-outflow" mode : "train" device : "cuda:0" - model: "REPST" + model: "AEPSA" + seed: 2023 data: add_day_in_week: true @@ -13,14 +14,14 @@ data: horizon: 12 lag: 12 normalizer: std - num_nodes: 170 - steps_per_day: 288 + num_nodes: 142 + steps_per_day: 48 test_ratio: 0.2 tod: false val_ratio: 0.2 sample: 1 input_dim: 1 - batch_size: 64 + batch_size: 32 model: pred_len: 12 @@ -33,9 +34,11 @@ model: gpt_path: ./GPT-2 d_model: 64 n_heads: 1 + input_dim: 1 + word_num: 1000 train: - batch_size: 64 + batch_size: 32 early_stop: true early_stop_patience: 15 epochs: 100 @@ -46,13 +49,10 @@ train: lr_decay_step: "5,20,40,70" lr_init: 0.003 max_grad_norm: 5 - real_value: true - seed: 12 weight_decay: 0 debug: false output_dim: 1 - log_step: 2000 + log_step: 100 plot: false mae_thresh: None mape_thresh: 0.001 - diff --git a/config/AEPSA/METR-LA.yaml b/config/AEPSA/METR-LA.yaml new file mode 100644 index 0000000..b24f448 --- /dev/null +++ b/config/AEPSA/METR-LA.yaml @@ -0,0 +1,59 @@ +basic: + dataset: "METR-LA" + mode : "train" + device : "cuda:0" + model: "AEPSA" + seed: 2023 + +data: + add_day_in_week: true + add_time_in_day: true + column_wise: false + days_per_week: 7 + default_graph: true + horizon: 24 + lag: 24 + normalizer: std + num_nodes: 207 + steps_per_day: 288 + test_ratio: 0.2 + tod: false + val_ratio: 0.2 + sample: 1 + input_dim: 1 + batch_size: 16 + +model: + pred_len: 24 + seq_len: 24 + patch_len: 6 + stride: 7 + dropout: 0.2 + gpt_layers: 9 + d_ff: 128 + gpt_path: ./GPT-2 + d_model: 64 + n_heads: 1 + input_dim: 1 + word_num: 1000 + +train: + batch_size: 16 + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + lr_init: 0.003 + max_grad_norm: 5 + real_value: true + weight_decay: 0 + debug: false + output_dim: 1 + log_step: 1000 + plot: false + mae_thresh: None + mape_thresh: 0.001 diff --git a/config/AEPSA/NYCBike-inflow.yaml b/config/AEPSA/NYCBike-inflow.yaml new file mode 100644 index 0000000..ee0bd06 --- /dev/null +++ b/config/AEPSA/NYCBike-inflow.yaml @@ -0,0 +1,58 @@ +basic: + dataset: "NYCBike-inflow" + mode : "train" + device : "cuda:0" + model: "AEPSA" + seed: 2023 + +data: + add_day_in_week: true + add_time_in_day: true + column_wise: false + days_per_week: 7 + default_graph: true + horizon: 12 + lag: 12 + normalizer: std + num_nodes: 200 + steps_per_day: 24 + test_ratio: 0.2 + tod: false + val_ratio: 0.2 + sample: 1 + input_dim: 1 + batch_size: 32 + +model: + pred_len: 12 + seq_len: 12 + patch_len: 6 + stride: 7 + dropout: 0.2 + gpt_layers: 9 + d_ff: 128 + gpt_path: ./GPT-2 + d_model: 64 + n_heads: 1 + input_dim: 1 + word_num: 1000 + +train: + batch_size: 32 + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + lr_init: 0.003 + max_grad_norm: 5 + weight_decay: 0 + debug: false + output_dim: 1 + log_step: 100 + plot: false + mae_thresh: None + mape_thresh: 0.001 diff --git a/config/AEPSA/NYCBike-outflow.yaml b/config/AEPSA/NYCBike-outflow.yaml new file mode 100644 index 0000000..a620e0e --- /dev/null +++ b/config/AEPSA/NYCBike-outflow.yaml @@ -0,0 +1,58 @@ +basic: + dataset: "NYCBike-outflow" + mode : "train" + device : "cuda:0" + model: "AEPSA" + seed: 2023 + +data: + add_day_in_week: true + add_time_in_day: true + column_wise: false + days_per_week: 7 + default_graph: true + horizon: 12 + lag: 12 + normalizer: std + num_nodes: 200 + steps_per_day: 24 + test_ratio: 0.2 + tod: false + val_ratio: 0.2 + sample: 1 + input_dim: 1 + batch_size: 32 + +model: + pred_len: 12 + seq_len: 12 + patch_len: 6 + stride: 7 + dropout: 0.2 + gpt_layers: 9 + d_ff: 128 + gpt_path: ./GPT-2 + d_model: 64 + n_heads: 1 + input_dim: 1 + word_num: 1000 + +train: + batch_size: 32 + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + lr_init: 0.003 + max_grad_norm: 5 + weight_decay: 0 + debug: false + output_dim: 1 + log_step: 100 + plot: false + mae_thresh: None + mape_thresh: 0.001 diff --git a/config/AEPSA/PEMS-BAY.yaml b/config/AEPSA/PEMS-BAY.yaml index 9330f9b..99d58a1 100755 --- a/config/AEPSA/PEMS-BAY.yaml +++ b/config/AEPSA/PEMS-BAY.yaml @@ -3,6 +3,7 @@ basic: mode : "train" device : "cuda:0" model: "AEPSA" + seed: 2023 data: add_day_in_week: true @@ -34,6 +35,7 @@ model: d_model: 64 n_heads: 1 input_dim: 1 + word_num: 1000 train: batch_size: 16 @@ -47,8 +49,6 @@ train: lr_decay_step: "5,20,40,70" lr_init: 0.003 max_grad_norm: 5 - real_value: true - seed: 12 weight_decay: 0 debug: false output_dim: 1 diff --git a/config/AEPSA/SolarEnergy.yaml b/config/AEPSA/SolarEnergy.yaml new file mode 100644 index 0000000..d6cd736 --- /dev/null +++ b/config/AEPSA/SolarEnergy.yaml @@ -0,0 +1,59 @@ +basic: + dataset: "SolarEnergy" + mode : "train" + device : "cuda:0" + model: "AEPSA" + seed: 2023 + +data: + add_day_in_week: true + add_time_in_day: true + column_wise: false + days_per_week: 7 + default_graph: true + horizon: 24 + lag: 24 + normalizer: std + num_nodes: 137 + steps_per_day: 24 + test_ratio: 0.2 + tod: false + val_ratio: 0.2 + sample: 1 + input_dim: 1 + batch_size: 64 + +model: + pred_len: 24 + seq_len: 24 + patch_len: 6 + stride: 7 + dropout: 0.2 + gpt_layers: 9 + d_ff: 128 + gpt_path: ./GPT-2 + d_model: 64 + n_heads: 1 + input_dim: 1 + word_num: 1000 + num_nodes: 137 + +train: + batch_size: 64 + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + lr_init: 0.003 + max_grad_norm: 5 + weight_decay: 0 + debug: false + output_dim: 1 + log_step: 100 + plot: false + mae_thresh: None + mape_thresh: 0.001 diff --git a/config/REPST/AirQuality.yaml b/config/REPST/AirQuality.yaml index b0683ae..fa716eb 100755 --- a/config/REPST/AirQuality.yaml +++ b/config/REPST/AirQuality.yaml @@ -37,6 +37,7 @@ model: input_dim: 6 output_dim: 3 word_num: 1000 + num_nodes: 35 train: batch_size: 16 diff --git a/config/REPST/BJTaxi-Inflow.yaml b/config/REPST/BJTaxi-Inflow.yaml index 37577a2..299711f 100755 --- a/config/REPST/BJTaxi-Inflow.yaml +++ b/config/REPST/BJTaxi-Inflow.yaml @@ -36,6 +36,7 @@ model: n_heads: 1 input_dim: 1 word_num: 1000 + num_nodes: 1024 train: batch_size: 16 diff --git a/config/REPST/BJTaxi-outflow.yaml b/config/REPST/BJTaxi-outflow.yaml index f9994ee..a453a63 100755 --- a/config/REPST/BJTaxi-outflow.yaml +++ b/config/REPST/BJTaxi-outflow.yaml @@ -36,6 +36,7 @@ model: n_heads: 1 input_dim: 1 word_num: 1000 + num_nodes: 1024 train: batch_size: 16 diff --git a/config/REPST/BeijingAirQuality(Deprecated).yaml b/config/REPST/BeijingAirQuality(Deprecated).yaml index 595c971..e6559f3 100755 --- a/config/REPST/BeijingAirQuality(Deprecated).yaml +++ b/config/REPST/BeijingAirQuality(Deprecated).yaml @@ -37,6 +37,7 @@ model: input_dim: 3 output_dim: 3 word_num: 1000 + num_nodes: 7 train: batch_size: 16 diff --git a/config/REPST/METR-LA.yaml b/config/REPST/METR-LA.yaml index 1e3e29d..f6653ad 100755 --- a/config/REPST/METR-LA.yaml +++ b/config/REPST/METR-LA.yaml @@ -36,6 +36,7 @@ model: n_heads: 1 input_dim: 1 word_num: 1000 + num_nodes: 207 train: batch_size: 16 diff --git a/config/REPST/NYCBike-inflow.yaml b/config/REPST/NYCBike-inflow.yaml index c7b0bd7..04b1553 100755 --- a/config/REPST/NYCBike-inflow.yaml +++ b/config/REPST/NYCBike-inflow.yaml @@ -36,6 +36,7 @@ model: n_heads: 1 input_dim: 1 word_num: 1000 + num_nodes: 128 train: batch_size: 16 diff --git a/config/REPST/NYCBike-outflow.yaml b/config/REPST/NYCBike-outflow.yaml index a03e3c8..dd9ca48 100755 --- a/config/REPST/NYCBike-outflow.yaml +++ b/config/REPST/NYCBike-outflow.yaml @@ -36,6 +36,7 @@ model: n_heads: 1 input_dim: 1 word_num: 1000 + num_nodes: 128 train: batch_size: 16 diff --git a/config/REPST/PEMS-BAY.yaml b/config/REPST/PEMS-BAY.yaml index 60eb800..dd1b02f 100755 --- a/config/REPST/PEMS-BAY.yaml +++ b/config/REPST/PEMS-BAY.yaml @@ -36,6 +36,7 @@ model: n_heads: 1 input_dim: 1 word_num: 1000 + num_nodes: 325 train: batch_size: 16 diff --git a/config/REPST/PEMS-BAY_paper.yaml b/config/REPST/PEMS-BAY_paper.yaml index a0db8c9..540a10e 100755 --- a/config/REPST/PEMS-BAY_paper.yaml +++ b/config/REPST/PEMS-BAY_paper.yaml @@ -35,6 +35,7 @@ model: n_heads: 1 input_dim: 1 t_max: 5 + num_nodes: 325 train: batch_size: 16 diff --git a/config/REPST/SolarEnergy.yaml b/config/REPST/SolarEnergy.yaml index 282c929..73af14a 100755 --- a/config/REPST/SolarEnergy.yaml +++ b/config/REPST/SolarEnergy.yaml @@ -21,7 +21,7 @@ data: val_ratio: 0.2 sample: 1 input_dim: 1 - batch_size: 16 + batch_size: 64 model: pred_len: 24 @@ -36,9 +36,10 @@ model: n_heads: 1 input_dim: 1 word_num: 1000 + num_nodes: 137 train: - batch_size: 16 + batch_size: 64 early_stop: true early_stop_patience: 15 epochs: 100 diff --git a/model/AEPSA/aepsa.py b/model/AEPSA/aepsa.py new file mode 100644 index 0000000..409b585 --- /dev/null +++ b/model/AEPSA/aepsa.py @@ -0,0 +1,251 @@ +import torch +import torch.nn as nn +from transformers.models.gpt2.modeling_gpt2 import GPT2Model +from einops import rearrange +from model.AEPSA.normalizer import GumbelSoftmax +from model.AEPSA.reprogramming import PatchEmbedding, ReprogrammingLayer +import torch.nn.functional as F + +class DynamicGraphEnhancer(nn.Module): + """ + 动态图增强器,基于节点嵌入自动生成图结构 + 参考DDGCRN的设计,使用节点嵌入和特征信息动态计算邻接矩阵 + """ + def __init__(self, num_nodes, in_dim, embed_dim=10): + super().__init__() + self.num_nodes = num_nodes + self.embed_dim = embed_dim + + # 节点嵌入参数 + self.node_embeddings = nn.Parameter( + torch.randn(num_nodes, embed_dim), requires_grad=True + ) + + # 特征转换层,用于生成动态调整的嵌入 + self.feature_transform = nn.Sequential( + nn.Linear(in_dim, 16), + nn.Sigmoid(), + nn.Linear(16, 2), + nn.Sigmoid(), + nn.Linear(2, embed_dim) + ) + + # 注册单位矩阵作为固定的支持矩阵 + self.register_buffer("eye", torch.eye(num_nodes)) + + def get_laplacian(self, graph, I, normalize=True): + """ + 计算归一化拉普拉斯矩阵 + """ + # 计算度矩阵的逆平方根 + D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5)) + D_inv[torch.isinf(D_inv)] = 0.0 # 处理零除问题 + + if normalize: + return torch.matmul(torch.matmul(D_inv, graph), D_inv) + else: + return torch.matmul(torch.matmul(D_inv, graph + I), D_inv) + + def forward(self, X): + """ + X: 输入特征 [B, N, D] + 返回: 动态生成的归一化拉普拉斯矩阵 [B, N, N] + """ + batch_size = X.size(0) + laplacians = [] + + # 获取单位矩阵 + I = self.eye.to(X.device) + + for b in range(batch_size): + # 使用特征转换层生成动态嵌入调整因子 + filt = self.feature_transform(X[b]) # [N, embed_dim] + + # 计算节点嵌入向量 + nodevec = torch.tanh(self.node_embeddings * filt) + + # 通过节点嵌入的点积计算邻接矩阵 + adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1))) + + # 计算归一化拉普拉斯矩阵 + laplacian = self.get_laplacian(adj, I) + laplacians.append(laplacian) + + return torch.stack(laplacians, dim=0) + +class GraphEnhancedEncoder(nn.Module): + """ + 基于Chebyshev多项式和动态拉普拉斯矩阵的图增强编码器 + 用于将动态图结构信息整合到特征编码中 + """ + def __init__(self, K=3, in_dim=64, hidden_dim=32, num_nodes=325, embed_dim=10, device='cpu'): + super().__init__() + self.K = K # Chebyshev多项式阶数 + self.in_dim = in_dim + self.hidden_dim = hidden_dim + self.device = device + + # 动态图增强器 + self.graph_enhancer = DynamicGraphEnhancer(num_nodes, in_dim, embed_dim) + + # 谱系数 α_k (可学习参数) + self.alpha = nn.Parameter(torch.randn(K + 1, 1)) + + # 传播权重 W_k (可学习参数) + self.W = nn.ParameterList([ + nn.Parameter(torch.randn(in_dim, hidden_dim)) for _ in range(K + 1) + ]) + + self.to(device) + + def chebyshev_polynomials(self, L_tilde, X): + """递归计算 [T_0(L_tilde)X, ..., T_K(L_tilde)X]""" + T_k_list = [X] + if self.K >= 1: + T_k_list.append(torch.matmul(L_tilde, X)) + for k in range(2, self.K + 1): + T_k_list.append(2 * torch.matmul(L_tilde, T_k_list[-1]) - T_k_list[-2]) + return T_k_list + + def forward(self, X): + """ + X: 输入特征 [B, N, D] + 返回: 增强后的特征 [B, N, hidden_dim*(K+1)] + """ + batch_size, num_nodes, _ = X.shape + enhanced_features = [] + + # 动态生成拉普拉斯矩阵 + laplacians = self.graph_enhancer(X) + + for b in range(batch_size): + L = laplacians[b] + + # 特征值缩放 + try: + lambda_max = torch.linalg.eigvalsh(L).max().real + # 避免除零问题 + if lambda_max < 1e-6: + lambda_max = 1.0 + L_tilde = (2.0 / lambda_max) * L - torch.eye(L.size(0), device=L.device) + except: + # 如果计算特征值失败,使用单位矩阵 + L_tilde = torch.eye(num_nodes, device=X.device) + + # 计算Chebyshev多项式展开 + T_k_list = self.chebyshev_polynomials(L_tilde, X[b]) + H_list = [] + + # 应用传播权重 + for k in range(self.K + 1): + H_k = torch.matmul(T_k_list[k], self.W[k]) + H_list.append(H_k) + + # 拼接所有尺度的特征 + X_enhanced = torch.cat(H_list, dim=-1) # [N, hidden_dim*(K+1)] + enhanced_features.append(X_enhanced) + + return torch.stack(enhanced_features, dim=0) + +class AEPSA(nn.Module): + + def __init__(self, configs): + super(AEPSA, self).__init__() + self.device = configs['device'] + self.pred_len = configs['pred_len'] + self.seq_len = configs['seq_len'] + self.patch_len = configs['patch_len'] + self.input_dim = configs['input_dim'] + self.stride = configs['stride'] + self.dropout = configs['dropout'] + self.gpt_layers = configs['gpt_layers'] + self.d_ff = configs['d_ff'] + self.gpt_path = configs['gpt_path'] + self.num_nodes = configs.get('num_nodes', 325) # 添加节点数量配置 + + self.word_choice = GumbelSoftmax(configs['word_num']) + + self.d_model = configs['d_model'] + self.n_heads = configs['n_heads'] + self.d_keys = None + self.d_llm = 768 + + self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2) + self.head_nf = self.d_ff * self.patch_nums + + # 词嵌入 + self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim) + + # GPT2初始化 + self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True) + self.gpts.h = self.gpts.h[:self.gpt_layers] + self.gpts.apply(self.reset_parameters) + + self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device) + self.vocab_size = self.word_embeddings.shape[0] + self.mapping_layer = nn.Linear(self.vocab_size, 1) + self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm) + + # 添加动态图增强编码器 + self.graph_encoder = GraphEnhancedEncoder( + K=configs.get('chebyshev_order', 3), + in_dim=self.d_model, + hidden_dim=configs.get('graph_hidden_dim', 32), + num_nodes=self.num_nodes, + embed_dim=configs.get('graph_embed_dim', 10), + device=self.device + ) + + # 特征融合层 + self.feature_fusion = nn.Linear( + self.d_model + configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1), + self.d_model + ) + + self.out_mlp = nn.Sequential( + nn.Linear(self.d_llm, 128), + nn.ReLU(), + nn.Linear(128, self.pred_len) + ) + + for i, (name, param) in enumerate(self.gpts.named_parameters()): + if 'wpe' in name: + param.requires_grad = True + else: + param.requires_grad = False + + def reset_parameters(self, module): + if hasattr(module, 'weight') and module.weight is not None: + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if hasattr(module, 'bias') and module.bias is not None: + torch.nn.init.zeros_(module.bias) + + def forward(self, x): + """ + x: 输入数据 [B, T, N, C] + 自动生成图结构,无需外部提供邻接矩阵 + """ + x = x[..., :1] + x_enc = rearrange(x, 'b t n c -> b n c t') + enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, C) + # 应用图增强编码器(自动生成图结构) + graph_enhanced = self.graph_encoder(enc_out) + # 保持相同的维度 + + # 特征融合 - 现在两个张量都是三维的 [B, N, d_model] + enc_out = torch.cat([enc_out, graph_enhanced], dim=-1) + enc_out = self.feature_fusion(enc_out) + + self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0) + masks = self.word_choice(self.mapping_layer.weight.data.permute(1,0)) + source_embeddings = self.word_embeddings[masks==1] + + enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings) + enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state + + dec_out = self.out_mlp(enc_out) + outputs = dec_out.unsqueeze(dim=-1) + outputs = outputs.repeat(1, 1, 1, n_vars) + outputs = outputs.permute(0,2,1,3) + + return outputs diff --git a/model/AEPSA/repst.py b/model/AEPSA/repst.py deleted file mode 100644 index 53a6046..0000000 --- a/model/AEPSA/repst.py +++ /dev/null @@ -1,103 +0,0 @@ -import torch -import torch.nn as nn -from transformers.models.gpt2.modeling_gpt2 import GPT2Model -from einops import rearrange -from model.REPST.normalizer import GumbelSoftmax -from model.REPST.reprogramming import PatchEmbedding, ReprogrammingLayer - -class repst(nn.Module): - - def __init__(self, configs): - super(repst, self).__init__() - self.device = configs['device'] - self.pred_len = configs['pred_len'] - self.seq_len = configs['seq_len'] - self.patch_len = configs['patch_len'] - self.input_dim = configs['input_dim'] - self.stride = configs['stride'] - self.dropout = configs['dropout'] - self.gpt_layers = configs['gpt_layers'] - self.d_ff = configs['d_ff'] - self.gpt_path = configs['gpt_path'] - - self.word_choice = GumbelSoftmax(configs['word_num']) - - self.d_model = configs['d_model'] - self.n_heads = configs['n_heads'] - self.d_keys = None - self.d_llm = 768 - - self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2) - self.head_nf = self.d_ff * self.patch_nums - - # 词嵌入 - self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim) - - # GPT2初始化 - self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True) - self.gpts.h = self.gpts.h[:self.gpt_layers] - self.gpts.apply(self.reset_parameters) - - self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device) - self.vocab_size = self.word_embeddings.shape[0] - self.mapping_layer = nn.Linear(self.vocab_size, 1) - self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm) - - self.out_mlp = nn.Sequential( - nn.Linear(self.d_llm, 128), - nn.ReLU(), - nn.Linear(128, self.pred_len) - ) - - for i, (name, param) in enumerate(self.gpts.named_parameters()): - if 'wpe' in name: - param.requires_grad = True - else: - param.requires_grad = False - - def reset_parameters(self, module): - if hasattr(module, 'weight') and module.weight is not None: - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - if hasattr(module, 'bias') and module.bias is not None: - torch.nn.init.zeros_(module.bias) - - def forward(self, x): - x = x[..., :1] - x_enc = rearrange(x, 'b t n c -> b n c t') - enc_out, n_vars = self.patch_embedding(x_enc) - self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0) - masks = self.word_choice(self.mapping_layer.weight.data.permute(1,0)) - source_embeddings = self.word_embeddings[masks==1] - - enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings) - enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state - - dec_out = self.out_mlp(enc_out) - outputs = dec_out.unsqueeze(dim=-1) - outputs = outputs.repeat(1, 1, 1, n_vars) - outputs = outputs.permute(0,2,1,3) - - return outputs - -if __name__ == '__main__': - configs = { - 'device': 'cuda:0', - 'pred_len': 24, - 'seq_len': 24, - 'patch_len': 6, - 'stride': 7, - 'dropout': 0.2, - 'gpt_layers': 9, - 'd_ff': 128, - 'gpt_path': './GPT-2', - 'd_model': 64, - 'n_heads': 1, - 'input_dim': 1 - } - model = repst(configs) - x = torch.randn(16, 24, 325, 1) - y = model(x) - - print(y.shape) - - diff --git a/model/model_selector.py b/model/model_selector.py index 11edccd..5043202 100755 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -23,7 +23,7 @@ from model.ST_SSL.ST_SSL import STSSLModel from model.STGNRDE.Make_model import make_model as make_nrde_model from model.STAWnet.STAWnet import STAWnet from model.REPST.repst import repst as REPST -from model.AEPSA.repst import repst as AEPSA +from model.AEPSA.aepsa import AEPSA as AEPSA def model_selector(config): diff --git a/utils/Download_data.py b/utils/Download_data.py index f31b937..985b44d 100755 --- a/utils/Download_data.py +++ b/utils/Download_data.py @@ -160,6 +160,7 @@ def check_and_download_data(): file_path = f"Datasets/TaxiBJ/{file}" download_github_data(file_path, taxi_bj_floder) # 下载后更新缺失列表 + # download_and_extract("http://code.zhang-heng.com/static/BeijingTaxi.7z", data_dir) missing_list = detect_data_integrity(data_dir, file_tree) # 检查并下载TaxiBJ数据