From 440cb6936bea4e1ba4f855959f72d11787862c2d Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 3 Dec 2025 17:12:46 +0800 Subject: [PATCH] =?UTF-8?q?=E5=85=BC=E5=AE=B9STAEFormer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/AEPSA/AirQuality.yaml | 4 +-- config/STAEFormer/AirQuality.yaml | 10 +++--- config/STAEFormer/PEMS-BAY.yaml | 58 ++++++++++++++++++++++++++++++ config/STAEFormer/SolarEnergy.yaml | 8 ++--- model/STAEFormer/STAEFormer.py | 14 ++++---- run_tests.sh | 2 +- 6 files changed, 78 insertions(+), 18 deletions(-) create mode 100644 config/STAEFormer/PEMS-BAY.yaml diff --git a/config/AEPSA/AirQuality.yaml b/config/AEPSA/AirQuality.yaml index c2d905a..d6061d9 100644 --- a/config/AEPSA/AirQuality.yaml +++ b/config/AEPSA/AirQuality.yaml @@ -13,7 +13,7 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 12 + num_nodes: 35 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 @@ -26,7 +26,7 @@ model: gpt_path: ./GPT-2 input_dim: 6 n_heads: 1 - num_nodes: 12 + num_nodes: 35 patch_len: 6 pred_len: 24 seq_len: 24 diff --git a/config/STAEFormer/AirQuality.yaml b/config/STAEFormer/AirQuality.yaml index e5f07e8..b7956da 100644 --- a/config/STAEFormer/AirQuality.yaml +++ b/config/STAEFormer/AirQuality.yaml @@ -13,8 +13,8 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 12 - steps_per_day: 24 + num_nodes: 35 + steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 @@ -28,7 +28,7 @@ model: input_embedding_dim: 24 num_heads: 4 num_layers: 3 - num_nodes: 12 + num_nodes: 35 out_steps: 24 output_dim: 6 spatial_embedding_dim: 0 @@ -41,9 +41,9 @@ train: debug: false early_stop: true early_stop_patience: 15 - epochs: 300 + epochs: 100 grad_norm: false - log_step: 200 + log_step: 20000 loss_func: mae lr_decay: false lr_decay_rate: 0.3 diff --git a/config/STAEFormer/PEMS-BAY.yaml b/config/STAEFormer/PEMS-BAY.yaml new file mode 100644 index 0000000..353da67 --- /dev/null +++ b/config/STAEFormer/PEMS-BAY.yaml @@ -0,0 +1,58 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: STAEFormer + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + adaptive_embedding_dim: 80 + dow_embedding_dim: 24 + dropout: 0.1 + feed_forward_dim: 256 + in_steps: 24 + input_dim: 1 + input_embedding_dim: 24 + num_heads: 4 + num_layers: 3 + num_nodes: 325 + out_steps: 24 + output_dim: 1 + spatial_embedding_dim: 0 + steps_per_day: 288 + tod_embedding_dim: 24 + use_mixed_proj: true + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 300 + grad_norm: false + log_step: 200 + loss_func: mae + lr_decay: false + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: 0.0 + mape_thresh: 0.0 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 diff --git a/config/STAEFormer/SolarEnergy.yaml b/config/STAEFormer/SolarEnergy.yaml index fd97a63..c1151ca 100644 --- a/config/STAEFormer/SolarEnergy.yaml +++ b/config/STAEFormer/SolarEnergy.yaml @@ -24,13 +24,13 @@ model: dropout: 0.1 feed_forward_dim: 256 in_steps: 24 - input_dim: 137 + input_dim: 1 input_embedding_dim: 24 num_heads: 4 num_layers: 3 num_nodes: 137 out_steps: 24 - output_dim: 137 + output_dim: 1 spatial_embedding_dim: 0 steps_per_day: 24 tod_embedding_dim: 24 @@ -41,7 +41,7 @@ train: debug: false early_stop: true early_stop_patience: 15 - epochs: 300 + epochs: 100 grad_norm: false log_step: 200 loss_func: mae @@ -52,7 +52,7 @@ train: mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 - output_dim: 137 + output_dim: 1 plot: false real_value: true weight_decay: 0 diff --git a/model/STAEFormer/STAEFormer.py b/model/STAEFormer/STAEFormer.py index 63fdb01..91b8188 100755 --- a/model/STAEFormer/STAEFormer.py +++ b/model/STAEFormer/STAEFormer.py @@ -187,17 +187,19 @@ class STAEformer(nn.Module): batch_size = x.shape[0] if self.tod_embedding_dim > 0: - tod = x[..., 1] + tod = x[..., -2] if self.dow_embedding_dim > 0: - dow = x[..., 2] - x = x[..., 0:1] + dow = x[..., -1] + x = x[..., 0:self.input_dim] x = self.input_proj(x) # (batch_size, in_steps, num_nodes, input_embedding_dim) features = [x] if self.tod_embedding_dim > 0: - tod_emb = self.tod_embedding( - (tod * self.steps_per_day).long() - ) # (batch_size, in_steps, num_nodes, tod_embedding_dim) + # 确保索引在有效范围内 + tod_index = (tod * self.steps_per_day).long() + # 防止索引越界 + tod_index = torch.clamp(tod_index, 0, self.steps_per_day - 1) + tod_emb = self.tod_embedding(tod_index) # (batch_size, in_steps, num_nodes, tod_embedding_dim) features.append(tod_emb) if self.dow_embedding_dim > 0: dow_emb = self.dow_embedding( diff --git a/run_tests.sh b/run_tests.sh index a27a3bf..94080b6 100755 --- a/run_tests.sh +++ b/run_tests.sh @@ -1,7 +1,7 @@ #!/bin/bash # 设置默认模型名和数据集列表 -MODEL_NAME="GWN" +MODEL_NAME="STAEFormer" DATASETS=( "METR-LA" "PEMS-BAY"