AEPSA v0.1

This commit is contained in:
czzhangheng 2025-11-24 21:50:24 +08:00
parent 88daa53dd5
commit 9c50c30918
23 changed files with 664 additions and 151 deletions

110
.vscode/launch.json vendored
View File

@ -3,9 +3,11 @@
// //
// 访: https://go.microsoft.com/fwlink/?linkid=830387 // 访: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0", "version": "0.2.0",
"configurations": [ "configurations": [
// STID
{ {
"name": "STID_PEMS-BAY", "name": "STID: PEMS-BAY",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "run.py", "program": "run.py",
@ -13,7 +15,7 @@
"args": "--config ./config/STID/PEMS-BAY.yaml" "args": "--config ./config/STID/PEMS-BAY.yaml"
}, },
{ {
"name": "STID_PEMSD4", "name": "STID: PEMSD4",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "run.py", "program": "run.py",
@ -21,15 +23,7 @@
"args": "--config ./config/STID/PEMSD4.yaml" "args": "--config ./config/STID/PEMSD4.yaml"
}, },
{ {
"name": "REPST", "name": "STID: BJTaxi-InFlow",
"type": "debugpy",
"request": "launch",
"program": "run.py",
"console": "integratedTerminal",
"args": "--config ./config/REPST/PEMSD8.yaml"
},
{
"name": "STID-BJTaxi-InFlow",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "run.py", "program": "run.py",
@ -37,7 +31,7 @@
"args": "--config ./config/STID/BJTaxi_Inflow.yaml" "args": "--config ./config/STID/BJTaxi_Inflow.yaml"
}, },
{ {
"name": "STID-BJTaxi-OutFlow", "name": "STID: BJTaxi-OutFlow",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "run.py", "program": "run.py",
@ -45,7 +39,7 @@
"args": "--config ./config/STID/BJTaxi_Outflow.yaml" "args": "--config ./config/STID/BJTaxi_Outflow.yaml"
}, },
{ {
"name": "STID-NYCBike-InFlow", "name": "STID: NYCBike-InFlow",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "run.py", "program": "run.py",
@ -53,7 +47,7 @@
"args": "--config ./config/STID/NYCBike_Inflow.yaml" "args": "--config ./config/STID/NYCBike_Inflow.yaml"
}, },
{ {
"name": "STID-NYCBike-OutFlow", "name": "STID: NYCBike-OutFlow",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "run.py", "program": "run.py",
@ -61,15 +55,25 @@
"args": "--config ./config/STID/NYCBike_Outflow.yaml" "args": "--config ./config/STID/NYCBike_Outflow.yaml"
}, },
{ {
"name": "STID-SolarEnergy", "name": "STID: SolarEnergy",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "run.py", "program": "run.py",
"console": "integratedTerminal", "console": "integratedTerminal",
"args": "--config ./config/STID/SolarEnergy.yaml" "args": "--config ./config/STID/SolarEnergy.yaml"
}, },
// REPST
{ {
"name": "REPST-BJTaxi-InFlow", "name": "REPST: PEMSD8",
"type": "debugpy",
"request": "launch",
"program": "run.py",
"console": "integratedTerminal",
"args": "--config ./config/REPST/PEMSD8.yaml"
},
{
"name": "REPST: BJTaxi-InFlow",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "run.py", "program": "run.py",
@ -77,7 +81,7 @@
"args": "--config ./config/REPST/BJTaxi-Inflow.yaml" "args": "--config ./config/REPST/BJTaxi-Inflow.yaml"
}, },
{ {
"name": "REPST-NYCBike-outflow", "name": "REPST: NYCBike-outflow",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "run.py", "program": "run.py",
@ -85,7 +89,7 @@
"args": "--config ./config/REPST/NYCBike-outflow.yaml" "args": "--config ./config/REPST/NYCBike-outflow.yaml"
}, },
{ {
"name": "REPST-NYCBike-inflow", "name": "REPST: NYCBike-inflow",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "run.py", "program": "run.py",
@ -93,7 +97,7 @@
"args": "--config ./config/REPST/NYCBike-inflow.yaml" "args": "--config ./config/REPST/NYCBike-inflow.yaml"
}, },
{ {
"name": "REPST-PEMSBAY", "name": "REPST: PEMS-BAY",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "run.py", "program": "run.py",
@ -101,7 +105,7 @@
"args": "--config ./config/REPST/PEMS-BAY.yaml" "args": "--config ./config/REPST/PEMS-BAY.yaml"
}, },
{ {
"name": "REPST-METR", "name": "REPST: METR-LA",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "run.py", "program": "run.py",
@ -109,7 +113,7 @@
"args": "--config ./config/REPST/METR-LA.yaml" "args": "--config ./config/REPST/METR-LA.yaml"
}, },
{ {
"name": "REPST-Solar", "name": "REPST: SolarEnergy",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "run.py", "program": "run.py",
@ -117,7 +121,7 @@
"args": "--config ./config/REPST/SolarEnergy.yaml" "args": "--config ./config/REPST/SolarEnergy.yaml"
}, },
{ {
"name": "BeijingAirQuality", "name": "REPST: BeijingAirQuality",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "run.py", "program": "run.py",
@ -125,20 +129,78 @@
"args": "--config ./config/REPST/BeijingAirQuality.yaml" "args": "--config ./config/REPST/BeijingAirQuality.yaml"
}, },
{ {
"name": "AirQuality", "name": "REPST: AirQuality",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "run.py", "program": "run.py",
"console": "integratedTerminal", "console": "integratedTerminal",
"args": "--config ./config/REPST/AirQuality.yaml" "args": "--config ./config/REPST/AirQuality.yaml"
}, },
// AEPSA
{ {
"name": "AEPSA-PEMSBAY", "name": "AEPSA: PEMS-BAY",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "run.py", "program": "run.py",
"console": "integratedTerminal", "console": "integratedTerminal",
"args": "--config ./config/AEPSA/PEMS-BAY.yaml" "args": "--config ./config/AEPSA/PEMS-BAY.yaml"
},
{
"name": "AEPSA: METR-LA",
"type": "debugpy",
"request": "launch",
"program": "run.py",
"console": "integratedTerminal",
"args": "--config ./config/AEPSA/METR-LA.yaml"
},
{
"name": "AEPSA: AirQuality",
"type": "debugpy",
"request": "launch",
"program": "run.py",
"console": "integratedTerminal",
"args": "--config ./config/AEPSA/AirQuality.yaml"
},
{
"name": "AEPSA: BJTaxi-Inflow",
"type": "debugpy",
"request": "launch",
"program": "run.py",
"console": "integratedTerminal",
"args": "--config ./config/AEPSA/BJTaxi-Inflow.yaml"
},
{
"name": "AEPSA: BJTaxi-outflow",
"type": "debugpy",
"request": "launch",
"program": "run.py",
"console": "integratedTerminal",
"args": "--config ./config/AEPSA/BJTaxi-outflow.yaml"
},
{
"name": "AEPSA: NYCBike-inflow",
"type": "debugpy",
"request": "launch",
"program": "run.py",
"console": "integratedTerminal",
"args": "--config ./config/AEPSA/NYCBike-inflow.yaml"
},
{
"name": "AEPSA: NYCBike-outflow",
"type": "debugpy",
"request": "launch",
"program": "run.py",
"console": "integratedTerminal",
"args": "--config ./config/AEPSA/NYCBike-outflow.yaml"
},
{
"name": "AEPSA: SolarEnergy",
"type": "debugpy",
"request": "launch",
"program": "run.py",
"console": "integratedTerminal",
"args": "--config ./config/AEPSA/SolarEnergy.yaml"
} }
] ]
} }

View File

@ -0,0 +1,58 @@
basic:
dataset: "AirQuality"
mode : "train"
device : "cuda:0"
model: "AEPSA"
seed: 2023
data:
add_day_in_week: true
add_time_in_day: true
column_wise: false
days_per_week: 7
default_graph: true
horizon: 24
lag: 24
normalizer: std
num_nodes: 35
steps_per_day: 24
test_ratio: 0.2
tod: false
val_ratio: 0.2
sample: 1
input_dim: 6
batch_size: 16
model:
pred_len: 24
seq_len: 24
patch_len: 6
stride: 7
dropout: 0.2
gpt_layers: 9
d_ff: 128
gpt_path: ./GPT-2
d_model: 64
n_heads: 1
input_dim: 6
word_num: 1000
train:
batch_size: 16
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
lr_init: 0.003
max_grad_norm: 5
weight_decay: 0
debug: false
output_dim: 6
log_step: 100
plot: false
mae_thresh: None
mape_thresh: 0.001

View File

@ -1,8 +1,9 @@
basic: basic:
dataset: "PEMSD8" dataset: "BJTaxi-Inflow"
mode : "train" mode : "train"
device : "cuda:0" device : "cuda:0"
model: "AEPSA" model: "AEPSA"
seed: 2023
data: data:
add_day_in_week: true add_day_in_week: true
@ -13,14 +14,14 @@ data:
horizon: 12 horizon: 12
lag: 12 lag: 12
normalizer: std normalizer: std
num_nodes: 170 num_nodes: 142
steps_per_day: 288 steps_per_day: 48
test_ratio: 0.2 test_ratio: 0.2
tod: false tod: false
val_ratio: 0.2 val_ratio: 0.2
sample: 1 sample: 1
input_dim: 1 input_dim: 1
batch_size: 64 batch_size: 32
model: model:
pred_len: 12 pred_len: 12
@ -33,9 +34,11 @@ model:
gpt_path: ./GPT-2 gpt_path: ./GPT-2
d_model: 64 d_model: 64
n_heads: 1 n_heads: 1
input_dim: 1
word_num: 1000
train: train:
batch_size: 64 batch_size: 32
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
epochs: 100 epochs: 100
@ -46,13 +49,10 @@ train:
lr_decay_step: "5,20,40,70" lr_decay_step: "5,20,40,70"
lr_init: 0.003 lr_init: 0.003
max_grad_norm: 5 max_grad_norm: 5
real_value: true
seed: 12
weight_decay: 0 weight_decay: 0
debug: false debug: false
output_dim: 1 output_dim: 1
log_step: 2000 log_step: 100
plot: false plot: false
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001

View File

@ -1,8 +1,9 @@
basic: basic:
dataset: "PEMSD8" dataset: "BJTaxi-outflow"
mode : "train" mode : "train"
device : "cuda:0" device : "cuda:0"
model: "REPST" model: "AEPSA"
seed: 2023
data: data:
add_day_in_week: true add_day_in_week: true
@ -13,14 +14,14 @@ data:
horizon: 12 horizon: 12
lag: 12 lag: 12
normalizer: std normalizer: std
num_nodes: 170 num_nodes: 142
steps_per_day: 288 steps_per_day: 48
test_ratio: 0.2 test_ratio: 0.2
tod: false tod: false
val_ratio: 0.2 val_ratio: 0.2
sample: 1 sample: 1
input_dim: 1 input_dim: 1
batch_size: 64 batch_size: 32
model: model:
pred_len: 12 pred_len: 12
@ -33,9 +34,11 @@ model:
gpt_path: ./GPT-2 gpt_path: ./GPT-2
d_model: 64 d_model: 64
n_heads: 1 n_heads: 1
input_dim: 1
word_num: 1000
train: train:
batch_size: 64 batch_size: 32
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
epochs: 100 epochs: 100
@ -46,13 +49,10 @@ train:
lr_decay_step: "5,20,40,70" lr_decay_step: "5,20,40,70"
lr_init: 0.003 lr_init: 0.003
max_grad_norm: 5 max_grad_norm: 5
real_value: true
seed: 12
weight_decay: 0 weight_decay: 0
debug: false debug: false
output_dim: 1 output_dim: 1
log_step: 2000 log_step: 100
plot: false plot: false
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001

59
config/AEPSA/METR-LA.yaml Normal file
View File

@ -0,0 +1,59 @@
basic:
dataset: "METR-LA"
mode : "train"
device : "cuda:0"
model: "AEPSA"
seed: 2023
data:
add_day_in_week: true
add_time_in_day: true
column_wise: false
days_per_week: 7
default_graph: true
horizon: 24
lag: 24
normalizer: std
num_nodes: 207
steps_per_day: 288
test_ratio: 0.2
tod: false
val_ratio: 0.2
sample: 1
input_dim: 1
batch_size: 16
model:
pred_len: 24
seq_len: 24
patch_len: 6
stride: 7
dropout: 0.2
gpt_layers: 9
d_ff: 128
gpt_path: ./GPT-2
d_model: 64
n_heads: 1
input_dim: 1
word_num: 1000
train:
batch_size: 16
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
lr_init: 0.003
max_grad_norm: 5
real_value: true
weight_decay: 0
debug: false
output_dim: 1
log_step: 1000
plot: false
mae_thresh: None
mape_thresh: 0.001

View File

@ -0,0 +1,58 @@
basic:
dataset: "NYCBike-inflow"
mode : "train"
device : "cuda:0"
model: "AEPSA"
seed: 2023
data:
add_day_in_week: true
add_time_in_day: true
column_wise: false
days_per_week: 7
default_graph: true
horizon: 12
lag: 12
normalizer: std
num_nodes: 200
steps_per_day: 24
test_ratio: 0.2
tod: false
val_ratio: 0.2
sample: 1
input_dim: 1
batch_size: 32
model:
pred_len: 12
seq_len: 12
patch_len: 6
stride: 7
dropout: 0.2
gpt_layers: 9
d_ff: 128
gpt_path: ./GPT-2
d_model: 64
n_heads: 1
input_dim: 1
word_num: 1000
train:
batch_size: 32
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
lr_init: 0.003
max_grad_norm: 5
weight_decay: 0
debug: false
output_dim: 1
log_step: 100
plot: false
mae_thresh: None
mape_thresh: 0.001

View File

@ -0,0 +1,58 @@
basic:
dataset: "NYCBike-outflow"
mode : "train"
device : "cuda:0"
model: "AEPSA"
seed: 2023
data:
add_day_in_week: true
add_time_in_day: true
column_wise: false
days_per_week: 7
default_graph: true
horizon: 12
lag: 12
normalizer: std
num_nodes: 200
steps_per_day: 24
test_ratio: 0.2
tod: false
val_ratio: 0.2
sample: 1
input_dim: 1
batch_size: 32
model:
pred_len: 12
seq_len: 12
patch_len: 6
stride: 7
dropout: 0.2
gpt_layers: 9
d_ff: 128
gpt_path: ./GPT-2
d_model: 64
n_heads: 1
input_dim: 1
word_num: 1000
train:
batch_size: 32
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
lr_init: 0.003
max_grad_norm: 5
weight_decay: 0
debug: false
output_dim: 1
log_step: 100
plot: false
mae_thresh: None
mape_thresh: 0.001

View File

@ -3,6 +3,7 @@ basic:
mode : "train" mode : "train"
device : "cuda:0" device : "cuda:0"
model: "AEPSA" model: "AEPSA"
seed: 2023
data: data:
add_day_in_week: true add_day_in_week: true
@ -34,6 +35,7 @@ model:
d_model: 64 d_model: 64
n_heads: 1 n_heads: 1
input_dim: 1 input_dim: 1
word_num: 1000
train: train:
batch_size: 16 batch_size: 16
@ -47,8 +49,6 @@ train:
lr_decay_step: "5,20,40,70" lr_decay_step: "5,20,40,70"
lr_init: 0.003 lr_init: 0.003
max_grad_norm: 5 max_grad_norm: 5
real_value: true
seed: 12
weight_decay: 0 weight_decay: 0
debug: false debug: false
output_dim: 1 output_dim: 1

View File

@ -0,0 +1,59 @@
basic:
dataset: "SolarEnergy"
mode : "train"
device : "cuda:0"
model: "AEPSA"
seed: 2023
data:
add_day_in_week: true
add_time_in_day: true
column_wise: false
days_per_week: 7
default_graph: true
horizon: 24
lag: 24
normalizer: std
num_nodes: 137
steps_per_day: 24
test_ratio: 0.2
tod: false
val_ratio: 0.2
sample: 1
input_dim: 1
batch_size: 64
model:
pred_len: 24
seq_len: 24
patch_len: 6
stride: 7
dropout: 0.2
gpt_layers: 9
d_ff: 128
gpt_path: ./GPT-2
d_model: 64
n_heads: 1
input_dim: 1
word_num: 1000
num_nodes: 137
train:
batch_size: 64
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
lr_init: 0.003
max_grad_norm: 5
weight_decay: 0
debug: false
output_dim: 1
log_step: 100
plot: false
mae_thresh: None
mape_thresh: 0.001

View File

@ -37,6 +37,7 @@ model:
input_dim: 6 input_dim: 6
output_dim: 3 output_dim: 3
word_num: 1000 word_num: 1000
num_nodes: 35
train: train:
batch_size: 16 batch_size: 16

View File

@ -36,6 +36,7 @@ model:
n_heads: 1 n_heads: 1
input_dim: 1 input_dim: 1
word_num: 1000 word_num: 1000
num_nodes: 1024
train: train:
batch_size: 16 batch_size: 16

View File

@ -36,6 +36,7 @@ model:
n_heads: 1 n_heads: 1
input_dim: 1 input_dim: 1
word_num: 1000 word_num: 1000
num_nodes: 1024
train: train:
batch_size: 16 batch_size: 16

View File

@ -37,6 +37,7 @@ model:
input_dim: 3 input_dim: 3
output_dim: 3 output_dim: 3
word_num: 1000 word_num: 1000
num_nodes: 7
train: train:
batch_size: 16 batch_size: 16

View File

@ -36,6 +36,7 @@ model:
n_heads: 1 n_heads: 1
input_dim: 1 input_dim: 1
word_num: 1000 word_num: 1000
num_nodes: 207
train: train:
batch_size: 16 batch_size: 16

View File

@ -36,6 +36,7 @@ model:
n_heads: 1 n_heads: 1
input_dim: 1 input_dim: 1
word_num: 1000 word_num: 1000
num_nodes: 128
train: train:
batch_size: 16 batch_size: 16

View File

@ -36,6 +36,7 @@ model:
n_heads: 1 n_heads: 1
input_dim: 1 input_dim: 1
word_num: 1000 word_num: 1000
num_nodes: 128
train: train:
batch_size: 16 batch_size: 16

View File

@ -36,6 +36,7 @@ model:
n_heads: 1 n_heads: 1
input_dim: 1 input_dim: 1
word_num: 1000 word_num: 1000
num_nodes: 325
train: train:
batch_size: 16 batch_size: 16

View File

@ -35,6 +35,7 @@ model:
n_heads: 1 n_heads: 1
input_dim: 1 input_dim: 1
t_max: 5 t_max: 5
num_nodes: 325
train: train:
batch_size: 16 batch_size: 16

View File

@ -21,7 +21,7 @@ data:
val_ratio: 0.2 val_ratio: 0.2
sample: 1 sample: 1
input_dim: 1 input_dim: 1
batch_size: 16 batch_size: 64
model: model:
pred_len: 24 pred_len: 24
@ -36,9 +36,10 @@ model:
n_heads: 1 n_heads: 1
input_dim: 1 input_dim: 1
word_num: 1000 word_num: 1000
num_nodes: 137
train: train:
batch_size: 16 batch_size: 64
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
epochs: 100 epochs: 100

251
model/AEPSA/aepsa.py Normal file
View File

@ -0,0 +1,251 @@
import torch
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from einops import rearrange
from model.AEPSA.normalizer import GumbelSoftmax
from model.AEPSA.reprogramming import PatchEmbedding, ReprogrammingLayer
import torch.nn.functional as F
class DynamicGraphEnhancer(nn.Module):
"""
动态图增强器基于节点嵌入自动生成图结构
参考DDGCRN的设计使用节点嵌入和特征信息动态计算邻接矩阵
"""
def __init__(self, num_nodes, in_dim, embed_dim=10):
super().__init__()
self.num_nodes = num_nodes
self.embed_dim = embed_dim
# 节点嵌入参数
self.node_embeddings = nn.Parameter(
torch.randn(num_nodes, embed_dim), requires_grad=True
)
# 特征转换层,用于生成动态调整的嵌入
self.feature_transform = nn.Sequential(
nn.Linear(in_dim, 16),
nn.Sigmoid(),
nn.Linear(16, 2),
nn.Sigmoid(),
nn.Linear(2, embed_dim)
)
# 注册单位矩阵作为固定的支持矩阵
self.register_buffer("eye", torch.eye(num_nodes))
def get_laplacian(self, graph, I, normalize=True):
"""
计算归一化拉普拉斯矩阵
"""
# 计算度矩阵的逆平方根
D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))
D_inv[torch.isinf(D_inv)] = 0.0 # 处理零除问题
if normalize:
return torch.matmul(torch.matmul(D_inv, graph), D_inv)
else:
return torch.matmul(torch.matmul(D_inv, graph + I), D_inv)
def forward(self, X):
"""
X: 输入特征 [B, N, D]
返回: 动态生成的归一化拉普拉斯矩阵 [B, N, N]
"""
batch_size = X.size(0)
laplacians = []
# 获取单位矩阵
I = self.eye.to(X.device)
for b in range(batch_size):
# 使用特征转换层生成动态嵌入调整因子
filt = self.feature_transform(X[b]) # [N, embed_dim]
# 计算节点嵌入向量
nodevec = torch.tanh(self.node_embeddings * filt)
# 通过节点嵌入的点积计算邻接矩阵
adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1)))
# 计算归一化拉普拉斯矩阵
laplacian = self.get_laplacian(adj, I)
laplacians.append(laplacian)
return torch.stack(laplacians, dim=0)
class GraphEnhancedEncoder(nn.Module):
"""
基于Chebyshev多项式和动态拉普拉斯矩阵的图增强编码器
用于将动态图结构信息整合到特征编码中
"""
def __init__(self, K=3, in_dim=64, hidden_dim=32, num_nodes=325, embed_dim=10, device='cpu'):
super().__init__()
self.K = K # Chebyshev多项式阶数
self.in_dim = in_dim
self.hidden_dim = hidden_dim
self.device = device
# 动态图增强器
self.graph_enhancer = DynamicGraphEnhancer(num_nodes, in_dim, embed_dim)
# 谱系数 α_k (可学习参数)
self.alpha = nn.Parameter(torch.randn(K + 1, 1))
# 传播权重 W_k (可学习参数)
self.W = nn.ParameterList([
nn.Parameter(torch.randn(in_dim, hidden_dim)) for _ in range(K + 1)
])
self.to(device)
def chebyshev_polynomials(self, L_tilde, X):
"""递归计算 [T_0(L_tilde)X, ..., T_K(L_tilde)X]"""
T_k_list = [X]
if self.K >= 1:
T_k_list.append(torch.matmul(L_tilde, X))
for k in range(2, self.K + 1):
T_k_list.append(2 * torch.matmul(L_tilde, T_k_list[-1]) - T_k_list[-2])
return T_k_list
def forward(self, X):
"""
X: 输入特征 [B, N, D]
返回: 增强后的特征 [B, N, hidden_dim*(K+1)]
"""
batch_size, num_nodes, _ = X.shape
enhanced_features = []
# 动态生成拉普拉斯矩阵
laplacians = self.graph_enhancer(X)
for b in range(batch_size):
L = laplacians[b]
# 特征值缩放
try:
lambda_max = torch.linalg.eigvalsh(L).max().real
# 避免除零问题
if lambda_max < 1e-6:
lambda_max = 1.0
L_tilde = (2.0 / lambda_max) * L - torch.eye(L.size(0), device=L.device)
except:
# 如果计算特征值失败,使用单位矩阵
L_tilde = torch.eye(num_nodes, device=X.device)
# 计算Chebyshev多项式展开
T_k_list = self.chebyshev_polynomials(L_tilde, X[b])
H_list = []
# 应用传播权重
for k in range(self.K + 1):
H_k = torch.matmul(T_k_list[k], self.W[k])
H_list.append(H_k)
# 拼接所有尺度的特征
X_enhanced = torch.cat(H_list, dim=-1) # [N, hidden_dim*(K+1)]
enhanced_features.append(X_enhanced)
return torch.stack(enhanced_features, dim=0)
class AEPSA(nn.Module):
def __init__(self, configs):
super(AEPSA, self).__init__()
self.device = configs['device']
self.pred_len = configs['pred_len']
self.seq_len = configs['seq_len']
self.patch_len = configs['patch_len']
self.input_dim = configs['input_dim']
self.stride = configs['stride']
self.dropout = configs['dropout']
self.gpt_layers = configs['gpt_layers']
self.d_ff = configs['d_ff']
self.gpt_path = configs['gpt_path']
self.num_nodes = configs.get('num_nodes', 325) # 添加节点数量配置
self.word_choice = GumbelSoftmax(configs['word_num'])
self.d_model = configs['d_model']
self.n_heads = configs['n_heads']
self.d_keys = None
self.d_llm = 768
self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2)
self.head_nf = self.d_ff * self.patch_nums
# 词嵌入
self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim)
# GPT2初始化
self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
self.gpts.h = self.gpts.h[:self.gpt_layers]
self.gpts.apply(self.reset_parameters)
self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)
self.vocab_size = self.word_embeddings.shape[0]
self.mapping_layer = nn.Linear(self.vocab_size, 1)
self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm)
# 添加动态图增强编码器
self.graph_encoder = GraphEnhancedEncoder(
K=configs.get('chebyshev_order', 3),
in_dim=self.d_model,
hidden_dim=configs.get('graph_hidden_dim', 32),
num_nodes=self.num_nodes,
embed_dim=configs.get('graph_embed_dim', 10),
device=self.device
)
# 特征融合层
self.feature_fusion = nn.Linear(
self.d_model + configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1),
self.d_model
)
self.out_mlp = nn.Sequential(
nn.Linear(self.d_llm, 128),
nn.ReLU(),
nn.Linear(128, self.pred_len)
)
for i, (name, param) in enumerate(self.gpts.named_parameters()):
if 'wpe' in name:
param.requires_grad = True
else:
param.requires_grad = False
def reset_parameters(self, module):
if hasattr(module, 'weight') and module.weight is not None:
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if hasattr(module, 'bias') and module.bias is not None:
torch.nn.init.zeros_(module.bias)
def forward(self, x):
"""
x: 输入数据 [B, T, N, C]
自动生成图结构无需外部提供邻接矩阵
"""
x = x[..., :1]
x_enc = rearrange(x, 'b t n c -> b n c t')
enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, C)
# 应用图增强编码器(自动生成图结构)
graph_enhanced = self.graph_encoder(enc_out)
# 保持相同的维度
# 特征融合 - 现在两个张量都是三维的 [B, N, d_model]
enc_out = torch.cat([enc_out, graph_enhanced], dim=-1)
enc_out = self.feature_fusion(enc_out)
self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
masks = self.word_choice(self.mapping_layer.weight.data.permute(1,0))
source_embeddings = self.word_embeddings[masks==1]
enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state
dec_out = self.out_mlp(enc_out)
outputs = dec_out.unsqueeze(dim=-1)
outputs = outputs.repeat(1, 1, 1, n_vars)
outputs = outputs.permute(0,2,1,3)
return outputs

View File

@ -1,103 +0,0 @@
import torch
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from einops import rearrange
from model.REPST.normalizer import GumbelSoftmax
from model.REPST.reprogramming import PatchEmbedding, ReprogrammingLayer
class repst(nn.Module):
def __init__(self, configs):
super(repst, self).__init__()
self.device = configs['device']
self.pred_len = configs['pred_len']
self.seq_len = configs['seq_len']
self.patch_len = configs['patch_len']
self.input_dim = configs['input_dim']
self.stride = configs['stride']
self.dropout = configs['dropout']
self.gpt_layers = configs['gpt_layers']
self.d_ff = configs['d_ff']
self.gpt_path = configs['gpt_path']
self.word_choice = GumbelSoftmax(configs['word_num'])
self.d_model = configs['d_model']
self.n_heads = configs['n_heads']
self.d_keys = None
self.d_llm = 768
self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2)
self.head_nf = self.d_ff * self.patch_nums
# 词嵌入
self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim)
# GPT2初始化
self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
self.gpts.h = self.gpts.h[:self.gpt_layers]
self.gpts.apply(self.reset_parameters)
self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)
self.vocab_size = self.word_embeddings.shape[0]
self.mapping_layer = nn.Linear(self.vocab_size, 1)
self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm)
self.out_mlp = nn.Sequential(
nn.Linear(self.d_llm, 128),
nn.ReLU(),
nn.Linear(128, self.pred_len)
)
for i, (name, param) in enumerate(self.gpts.named_parameters()):
if 'wpe' in name:
param.requires_grad = True
else:
param.requires_grad = False
def reset_parameters(self, module):
if hasattr(module, 'weight') and module.weight is not None:
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if hasattr(module, 'bias') and module.bias is not None:
torch.nn.init.zeros_(module.bias)
def forward(self, x):
x = x[..., :1]
x_enc = rearrange(x, 'b t n c -> b n c t')
enc_out, n_vars = self.patch_embedding(x_enc)
self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
masks = self.word_choice(self.mapping_layer.weight.data.permute(1,0))
source_embeddings = self.word_embeddings[masks==1]
enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state
dec_out = self.out_mlp(enc_out)
outputs = dec_out.unsqueeze(dim=-1)
outputs = outputs.repeat(1, 1, 1, n_vars)
outputs = outputs.permute(0,2,1,3)
return outputs
if __name__ == '__main__':
configs = {
'device': 'cuda:0',
'pred_len': 24,
'seq_len': 24,
'patch_len': 6,
'stride': 7,
'dropout': 0.2,
'gpt_layers': 9,
'd_ff': 128,
'gpt_path': './GPT-2',
'd_model': 64,
'n_heads': 1,
'input_dim': 1
}
model = repst(configs)
x = torch.randn(16, 24, 325, 1)
y = model(x)
print(y.shape)

View File

@ -23,7 +23,7 @@ from model.ST_SSL.ST_SSL import STSSLModel
from model.STGNRDE.Make_model import make_model as make_nrde_model from model.STGNRDE.Make_model import make_model as make_nrde_model
from model.STAWnet.STAWnet import STAWnet from model.STAWnet.STAWnet import STAWnet
from model.REPST.repst import repst as REPST from model.REPST.repst import repst as REPST
from model.AEPSA.repst import repst as AEPSA from model.AEPSA.aepsa import AEPSA as AEPSA
def model_selector(config): def model_selector(config):

View File

@ -160,6 +160,7 @@ def check_and_download_data():
file_path = f"Datasets/TaxiBJ/{file}" file_path = f"Datasets/TaxiBJ/{file}"
download_github_data(file_path, taxi_bj_floder) download_github_data(file_path, taxi_bj_floder)
# 下载后更新缺失列表 # 下载后更新缺失列表
# download_and_extract("http://code.zhang-heng.com/static/BeijingTaxi.7z", data_dir)
missing_list = detect_data_integrity(data_dir, file_tree) missing_list = detect_data_integrity(data_dir, file_tree)
# 检查并下载TaxiBJ数据 # 检查并下载TaxiBJ数据