更新v2配置

This commit is contained in:
czzhangheng 2025-12-10 10:39:41 +08:00
parent 5c2380ae21
commit 560d24e5a8
7 changed files with 67 additions and 11 deletions

4
.vscode/launch.json vendored
View File

@ -219,12 +219,12 @@
"args": "--config ./config/ASTRA/SolarEnergy.yaml" "args": "--config ./config/ASTRA/SolarEnergy.yaml"
}, },
{ {
"name": "ASTRA_v2: METR-LA", "name": "ASTRA_v2: AirQuality",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "run.py", "program": "run.py",
"console": "integratedTerminal", "console": "integratedTerminal",
"args": "--config ./config/ASTRA/v2_METR-LA.yaml" "args": "--config ./config/ASTRA/v2_AirQuality.yaml"
}, },
{ {
"name": "ASTRA_v2: SolarEnergy", "name": "ASTRA_v2: SolarEnergy",

View File

@ -0,0 +1,54 @@
basic:
dataset: AirQuality
device: cuda:0
mode: train
model: ASTRA_v2
seed: 2023
data:
batch_size: 16
column_wise: false
days_per_week: 7
horizon: 24
input_dim: 6
lag: 24
normalizer: std
num_nodes: 35
steps_per_day: 24
test_ratio: 0.2
val_ratio: 0.2
model:
d_ff: 128
d_model: 64
dropout: 0.2
gpt_layers: 9
gpt_path: ./GPT-2
input_dim: 6
n_heads: 1
num_nodes: 35
patch_len: 6
pred_len: 24
seq_len: 24
stride: 7
word_num: 1000
train:
batch_size: 16
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_step: 100
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 6
plot: false
weight_decay: 0

View File

@ -2,7 +2,7 @@ basic:
dataset: BJTaxi-InFlow dataset: BJTaxi-InFlow
device: cuda:0 device: cuda:0
mode: train mode: train
model: ASTRA model: ASTRA_v2
seed: 2023 seed: 2023
data: data:

View File

@ -2,7 +2,7 @@ basic:
dataset: BJTaxi-OutFlow dataset: BJTaxi-OutFlow
device: cuda:0 device: cuda:0
mode: train mode: train
model: ASTRA model: ASTRA_v2
seed: 2023 seed: 2023
data: data:

View File

@ -2,7 +2,7 @@ basic:
dataset: NYCBike-InFlow dataset: NYCBike-InFlow
device: cuda:0 device: cuda:0
mode: train mode: train
model: ASTRA model: ASTRA_v2
seed: 2023 seed: 2023
data: data:
@ -14,7 +14,7 @@ data:
lag: 24 lag: 24
normalizer: std normalizer: std
num_nodes: 128 num_nodes: 128
steps_per_day: 24 steps_per_day: 48
test_ratio: 0.2 test_ratio: 0.2
val_ratio: 0.2 val_ratio: 0.2

View File

@ -2,7 +2,7 @@ basic:
dataset: NYCBike-OutFlow dataset: NYCBike-OutFlow
device: cuda:0 device: cuda:0
mode: train mode: train
model: ASTRA model: ASTRA_v2
seed: 2023 seed: 2023
data: data:
@ -14,7 +14,7 @@ data:
lag: 24 lag: 24
normalizer: std normalizer: std
num_nodes: 128 num_nodes: 128
steps_per_day: 24 steps_per_day: 48
test_ratio: 0.2 test_ratio: 0.2
val_ratio: 0.2 val_ratio: 0.2

View File

@ -184,7 +184,7 @@ class ASTRA(nn.Module):
def forward(self, x): def forward(self, x):
# 数据处理 # 数据处理
x = x[..., :1] # [B,T,N,1] x = x[..., :self.input_dim] # [B,T,N,1]
x_enc = rearrange(x, 'b t n c -> b n c t') # [B,N,1,T] x_enc = rearrange(x, 'b t n c -> b n c t') # [B,N,1,T]
# 图编码 # 图编码
@ -202,7 +202,9 @@ class ASTRA(nn.Module):
dec_out = self.out_mlp(enc_out) # [B,N,pred_len] dec_out = self.out_mlp(enc_out) # [B,N,pred_len]
# 维度调整 # 维度调整
outputs = dec_out.unsqueeze(dim=-1) # [B,N,pred_len,1] dec_out = self.out_mlp(enc_out)
outputs = outputs.permute(0, 2, 1, 3) # [B,pred_len,N,1] outputs = dec_out.unsqueeze(dim=-1)
outputs = outputs.repeat(1, 1, 1, self.input_dim)
outputs = outputs.permute(0,2,1,3)
return outputs return outputs