From 560d24e5a86a1eb67a62303e6f61df692faf4f9b Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 10 Dec 2025 10:39:41 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0v2=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 4 +- config/ASTRA/v2_AirQuality.yaml | 54 +++++++++++++++++++ ...Taxi-Inflow.yaml => v2_BJTaxi-InFlow.yaml} | 2 +- ...xi-outflow.yaml => v2_BJTaxi-OutFlow.yaml} | 2 +- ...ike-inflow.yaml => v2_NYCBike-InFlow.yaml} | 4 +- ...e-outflow.yaml => v2_NYCBike-OutFlow.yaml} | 4 +- model/ASTRA/astrav2.py | 8 +-- 7 files changed, 67 insertions(+), 11 deletions(-) create mode 100644 config/ASTRA/v2_AirQuality.yaml rename config/ASTRA/{BJTaxi-Inflow.yaml => v2_BJTaxi-InFlow.yaml} (97%) rename config/ASTRA/{BJTaxi-outflow.yaml => v2_BJTaxi-OutFlow.yaml} (97%) rename config/ASTRA/{NYCBike-inflow.yaml => v2_NYCBike-InFlow.yaml} (95%) rename config/ASTRA/{NYCBike-outflow.yaml => v2_NYCBike-OutFlow.yaml} (95%) diff --git a/.vscode/launch.json b/.vscode/launch.json index cc2c023..54aad8a 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -219,12 +219,12 @@ "args": "--config ./config/ASTRA/SolarEnergy.yaml" }, { - "name": "ASTRA_v2: METR-LA", + "name": "ASTRA_v2: AirQuality", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", - "args": "--config ./config/ASTRA/v2_METR-LA.yaml" + "args": "--config ./config/ASTRA/v2_AirQuality.yaml" }, { "name": "ASTRA_v2: SolarEnergy", diff --git a/config/ASTRA/v2_AirQuality.yaml b/config/ASTRA/v2_AirQuality.yaml new file mode 100644 index 0000000..10796d2 --- /dev/null +++ b/config/ASTRA/v2_AirQuality.yaml @@ -0,0 +1,54 @@ +basic: + dataset: AirQuality + device: cuda:0 + mode: train + model: ASTRA_v2 + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_ff: 128 + d_model: 64 + dropout: 0.2 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 6 + n_heads: 1 + num_nodes: 35 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + word_num: 1000 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 6 + plot: false + weight_decay: 0 diff --git a/config/ASTRA/BJTaxi-Inflow.yaml b/config/ASTRA/v2_BJTaxi-InFlow.yaml similarity index 97% rename from config/ASTRA/BJTaxi-Inflow.yaml rename to config/ASTRA/v2_BJTaxi-InFlow.yaml index c2766bb..d1cc5ea 100644 --- a/config/ASTRA/BJTaxi-Inflow.yaml +++ b/config/ASTRA/v2_BJTaxi-InFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-InFlow device: cuda:0 mode: train - model: ASTRA + model: ASTRA_v2 seed: 2023 data: diff --git a/config/ASTRA/BJTaxi-outflow.yaml b/config/ASTRA/v2_BJTaxi-OutFlow.yaml similarity index 97% rename from config/ASTRA/BJTaxi-outflow.yaml rename to config/ASTRA/v2_BJTaxi-OutFlow.yaml index ee570f3..d6e0723 100644 --- a/config/ASTRA/BJTaxi-outflow.yaml +++ b/config/ASTRA/v2_BJTaxi-OutFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-OutFlow device: cuda:0 mode: train - model: ASTRA + model: ASTRA_v2 seed: 2023 data: diff --git a/config/ASTRA/NYCBike-inflow.yaml b/config/ASTRA/v2_NYCBike-InFlow.yaml similarity index 95% rename from config/ASTRA/NYCBike-inflow.yaml rename to config/ASTRA/v2_NYCBike-InFlow.yaml index 5431fba..de5b6a1 100644 --- a/config/ASTRA/NYCBike-inflow.yaml +++ b/config/ASTRA/v2_NYCBike-InFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-InFlow device: cuda:0 mode: train - model: ASTRA + model: ASTRA_v2 seed: 2023 data: @@ -14,7 +14,7 @@ data: lag: 24 normalizer: std num_nodes: 128 - steps_per_day: 24 + steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 diff --git a/config/ASTRA/NYCBike-outflow.yaml b/config/ASTRA/v2_NYCBike-OutFlow.yaml similarity index 95% rename from config/ASTRA/NYCBike-outflow.yaml rename to config/ASTRA/v2_NYCBike-OutFlow.yaml index 194c330..dda718d 100644 --- a/config/ASTRA/NYCBike-outflow.yaml +++ b/config/ASTRA/v2_NYCBike-OutFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-OutFlow device: cuda:0 mode: train - model: ASTRA + model: ASTRA_v2 seed: 2023 data: @@ -14,7 +14,7 @@ data: lag: 24 normalizer: std num_nodes: 128 - steps_per_day: 24 + steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 diff --git a/model/ASTRA/astrav2.py b/model/ASTRA/astrav2.py index 79a1330..6a47206 100644 --- a/model/ASTRA/astrav2.py +++ b/model/ASTRA/astrav2.py @@ -184,7 +184,7 @@ class ASTRA(nn.Module): def forward(self, x): # 数据处理 - x = x[..., :1] # [B,T,N,1] + x = x[..., :self.input_dim] # [B,T,N,1] x_enc = rearrange(x, 'b t n c -> b n c t') # [B,N,1,T] # 图编码 @@ -202,7 +202,9 @@ class ASTRA(nn.Module): dec_out = self.out_mlp(enc_out) # [B,N,pred_len] # 维度调整 - outputs = dec_out.unsqueeze(dim=-1) # [B,N,pred_len,1] - outputs = outputs.permute(0, 2, 1, 3) # [B,pred_len,N,1] + dec_out = self.out_mlp(enc_out) + outputs = dec_out.unsqueeze(dim=-1) + outputs = outputs.repeat(1, 1, 1, self.input_dim) + outputs = outputs.permute(0,2,1,3) return outputs \ No newline at end of file