This commit is contained in:
czzhangheng 2025-12-06 19:47:33 +08:00
parent 07d7d43857
commit 865c5a3082
5 changed files with 331 additions and 0 deletions

8
.vscode/launch.json vendored
View File

@@ -234,6 +234,14 @@
"console": "integratedTerminal", "console": "integratedTerminal",
"args": "--config ./config/AEPSA/v2_SolarEnergy.yaml" "args": "--config ./config/AEPSA/v2_SolarEnergy.yaml"
}, },
{
"name": "AEPSA_v3: METR-LA",
"type": "debugpy",
"request": "launch",
"program": "run.py",
"console": "integratedTerminal",
"args": "--config ./config/AEPSA/v3_METR-LA.yaml"
},
{ {
"name": "EXPB: NYCBike-InFlow", "name": "EXPB: NYCBike-InFlow",
"type": "debugpy", "type": "debugpy",

View File

@@ -0,0 +1,57 @@
basic:
dataset: METR-LA
device: cuda:0
mode: train
model: AEPSA_v3
seed: 2023
data:
batch_size: 16
column_wise: false
days_per_week: 7
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 207
steps_per_day: 288
test_ratio: 0.2
val_ratio: 0.2
model:
chebyshev_order: 3
d_ff: 128
d_model: 64
dropout: 0.2
graph_hidden_dim: 32
gpt_layers: 9
gpt_path: ./GPT-2
input_dim: 1
n_heads: 1
num_nodes: 207
patch_len: 6
pred_len: 24
seq_len: 24
stride: 7
word_num: 1000
train:
batch_size: 16
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_step: 1000
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
plot: false
real_value: true
weight_decay: 0

54
config/AEPSA/v3_PEMS-BAY.yaml Executable file
View File

@@ -0,0 +1,54 @@
basic:
dataset: PEMS-BAY
device: cuda:0
mode: train
model: AEPSA_v3
seed: 2023
data:
batch_size: 16
column_wise: false
days_per_week: 7
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 325
steps_per_day: 288
test_ratio: 0.2
val_ratio: 0.2
model:
d_ff: 128
d_model: 64
dropout: 0.2
gpt_layers: 9
gpt_path: ./GPT-2
input_dim: 1
n_heads: 1
num_nodes: 325
patch_len: 6
pred_len: 24
seq_len: 24
stride: 7
word_num: 1000
train:
batch_size: 16
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_step: 100
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
plot: false
weight_decay: 0

209
model/AEPSA/aepsav3.py Normal file
View File

@@ -0,0 +1,209 @@
import torch
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from einops import rearrange
from model.AEPSA.normalizer import GumbelSoftmax
from model.AEPSA.reprogramming import ReprogrammingLayer
import torch.nn.functional as F
# Spatio-temporal series forecasting model based on dynamic graph enhancement
class DynamicGraphEnhancer(nn.Module):
    """Dynamic graph enhancer.

    Learns static node embeddings, modulates them per sample with an MLP
    over the node features, and turns the result into a symmetrically
    normalized graph Laplacian — one per batch element.
    """

    def __init__(self, num_nodes, in_dim, embed_dim=10):
        super().__init__()
        self.num_nodes = num_nodes  # number of graph nodes
        self.embed_dim = embed_dim  # node-embedding width
        # Learnable static node embeddings, modulated per sample in forward().
        self.node_embeddings = nn.Parameter(torch.randn(num_nodes, embed_dim), requires_grad=True)
        # Bottleneck MLP mapping per-node features to an embedding-space gate.
        self.feature_transform = nn.Sequential(
            nn.Linear(in_dim, 16),
            nn.Sigmoid(),
            nn.Linear(16, 2),
            nn.Sigmoid(),
            nn.Linear(2, embed_dim)
        )
        self.register_buffer("eye", torch.eye(num_nodes))  # cached identity matrix

    def get_laplacian(self, graph, I, normalize=True):
        """Return the symmetrically normalized Laplacian D^-1/2 A D^-1/2.

        With ``normalize=False`` the self-loop identity ``I`` is added to the
        adjacency before scaling (degrees are still taken from ``graph``).
        """
        deg_inv_sqrt = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))
        deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0.0  # zero-degree nodes: avoid inf
        base = graph if normalize else graph + I
        return torch.matmul(torch.matmul(deg_inv_sqrt, base), deg_inv_sqrt)

    def forward(self, X):
        """Build one normalized Laplacian per sample.

        Args:
            X: per-node features; assumed shape [B, N, in_dim] — TODO confirm
               against caller (GraphEnhancedEncoder passes [B, N, in_dim]).

        Returns:
            Tensor of stacked Laplacians, shape [B, N, N].
        """
        identity = self.eye.to(X.device)

        def sample_laplacian(features):
            gate = self.feature_transform(features)                    # [N, embed_dim]
            modulated = torch.tanh(self.node_embeddings * gate)        # gated embeddings
            adjacency = F.relu(modulated @ modulated.transpose(0, 1))  # non-negative similarity
            return self.get_laplacian(adjacency, identity)

        return torch.stack([sample_laplacian(X[b]) for b in range(X.size(0))], dim=0)
class GraphEnhancedEncoder(nn.Module):
    """Graph-enhanced encoder.

    Projects each node's time series into a feature vector, builds a
    per-sample dynamic Laplacian, and propagates features with a Chebyshev
    polynomial expansion, concatenating the output of every order.
    """

    def __init__(self, K=3, in_dim=64, hidden_dim=32, num_nodes=325, embed_dim=10, device='cpu',
                 temporal_dim=12, num_features=1):
        super().__init__()
        self.K = K                        # Chebyshev polynomial order
        self.in_dim = in_dim              # projected feature width
        self.hidden_dim = hidden_dim      # per-order output width
        self.device = device              # target device
        self.temporal_dim = temporal_dim  # input sequence length
        self.num_features = num_features  # input feature channels
        # Temporal projection: second conv spans the full time axis, so the
        # time dimension collapses to 1 and is squeezed in forward().
        self.input_projection = nn.Sequential(
            nn.Conv2d(num_features, 16, kernel_size=(1, 3), padding=(0, 1)),
            nn.ReLU(),
            nn.Conv2d(16, in_dim, kernel_size=(1, temporal_dim)),
            nn.ReLU()
        )
        self.graph_enhancer = DynamicGraphEnhancer(num_nodes, in_dim, embed_dim)  # per-sample Laplacians
        self.alpha = nn.Parameter(torch.randn(K + 1, 1))  # spectral coefficients (not used in forward — TODO confirm intent)
        # One propagation weight matrix per Chebyshev order.
        self.W = nn.ParameterList([nn.Parameter(torch.randn(in_dim, hidden_dim)) for _ in range(K + 1)])
        self.to(device)

    def chebyshev_polynomials(self, L_tilde, X):
        """Return [T_0(X), ..., T_K(X)] for the scaled Laplacian ``L_tilde``."""
        T_k_list = [X]  # T_0(X) = X
        if self.K >= 1:
            T_k_list.append(torch.matmul(L_tilde, X))  # T_1(X) = L~ X
        for k in range(2, self.K + 1):
            # Recurrence: T_k = 2 L~ T_{k-1} - T_{k-2}
            T_k_list.append(2 * torch.matmul(L_tilde, T_k_list[-1]) - T_k_list[-2])
        return T_k_list

    def forward(self, X):
        """Input features [B, N, C, T]; returns enhanced features [B, N, hidden_dim*(K+1)]."""
        batch_size = X.size(0)
        num_nodes = X.size(1)
        x = X.permute(0, 2, 1, 3)                      # [B, C, N, T]
        x_proj = self.input_projection(x).squeeze(-1)  # [B, in_dim, N]
        x_proj = x_proj.permute(0, 2, 1)               # [B, N, in_dim]
        laplacians = self.graph_enhancer(x_proj)       # [B, N, N] dynamic Laplacians
        enhanced_features = []
        for b in range(batch_size):
            L = laplacians[b]
            # Rescale eigenvalues into [-1, 1] for the Chebyshev recurrence.
            try:
                # eigvalsh returns real eigenvalues for a symmetric matrix,
                # so the former `.real` access was redundant and is dropped.
                lambda_max = torch.linalg.eigvalsh(L).max()
                lambda_max = 1.0 if lambda_max < 1e-6 else lambda_max  # guard near-zero spectrum
                L_tilde = (2.0 / lambda_max) * L - torch.eye(L.size(0), device=L.device)
            except RuntimeError:
                # Eigendecomposition failed to converge — fall back to the
                # identity. (Was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit; torch.linalg errors derive
                # from RuntimeError.)
                L_tilde = torch.eye(num_nodes, device=X.device)
            # Expand and apply the per-order propagation weights.
            T_k_list = self.chebyshev_polynomials(L_tilde, x_proj[b])
            H_list = [torch.matmul(T_k_list[k], self.W[k]) for k in range(self.K + 1)]
            enhanced_features.append(torch.cat(H_list, dim=-1))
        # [B, N, hidden_dim*(K+1)]: per-node Chebyshev features of every order.
        return torch.stack(enhanced_features, dim=0)
class AEPSA(nn.Module):
    """Adaptive feature-projection spatio-temporal attention model (v3).

    Combines a dynamic-graph encoder with a truncated, re-initialized GPT-2
    backbone: graph-enhanced node features are reprogrammed against a learned
    subset of GPT-2 word embeddings, passed through the backbone (only the
    positional embeddings train), and decoded to the prediction horizon.
    """

    def __init__(self, configs):
        super(AEPSA, self).__init__()
        self.device = configs['device']          # target device
        self.pred_len = configs['pred_len']      # prediction horizon
        self.seq_len = configs['seq_len']        # input sequence length
        self.patch_len = configs['patch_len']    # patch length
        self.input_dim = configs['input_dim']    # input feature channels
        self.stride = configs['stride']          # patch stride
        self.dropout = configs['dropout']        # dropout probability
        self.gpt_layers = configs['gpt_layers']  # number of GPT-2 blocks kept
        self.d_ff = configs['d_ff']              # feed-forward hidden width
        self.gpt_path = configs['gpt_path']      # path to pretrained GPT-2
        self.num_nodes = configs.get('num_nodes', 325)          # node count
        self.word_choice = GumbelSoftmax(configs['word_num'])   # differentiable vocab selection
        self.d_model = configs['d_model']        # model width
        self.n_heads = configs['n_heads']        # attention heads
        self.d_keys = None                       # let ReprogrammingLayer derive the key width
        self.d_llm = 768                         # GPT-2 hidden size
        self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2)  # number of patches
        self.head_nf = self.d_ff * self.patch_nums  # head feature width
        # Load GPT-2, keep only the first `gpt_layers` blocks, then
        # re-initialize every weight (the pretrained values are discarded;
        # the word embeddings below therefore come from the re-init).
        self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
        self.gpts.h = self.gpts.h[:self.gpt_layers]
        self.gpts.apply(self.reset_parameters)
        self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)
        self.vocab_size = self.word_embeddings.shape[0]
        self.mapping_layer = nn.Linear(self.vocab_size, 1)  # per-vocab-entry score (its weight drives word_choice)
        # Width of the concatenated graph features fed to the reprogramming layer.
        graph_feat_dim = configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1)
        self.reprogramming_layer = ReprogrammingLayer(self.d_model + graph_feat_dim, self.n_heads, self.d_keys, self.d_llm)
        # Graph-enhanced encoder over the raw node time series.
        self.graph_encoder = GraphEnhancedEncoder(
            K=configs.get('chebyshev_order', 3),             # Chebyshev order
            in_dim=self.d_model,                             # projected feature width
            hidden_dim=configs.get('graph_hidden_dim', 32),  # per-order width
            num_nodes=self.num_nodes,                        # node count
            embed_dim=configs.get('graph_embed_dim', 10),    # node-embedding width
            device=self.device,                              # target device
            temporal_dim=self.seq_len,                       # sequence length
            num_features=self.input_dim                      # feature channels
        )
        # Projects the concatenated per-order Chebyshev features down to d_model.
        self.graph_projection = nn.Linear(graph_feat_dim, self.d_model)
        self.out_mlp = nn.Sequential(
            nn.Linear(self.d_llm, 128),
            nn.ReLU(),
            nn.Linear(128, self.pred_len)
        )
        # Freeze the backbone except the word-position embeddings (wpe).
        for name, param in self.gpts.named_parameters():
            param.requires_grad = 'wpe' in name

    def reset_parameters(self, module):
        """Re-initialize a submodule's weight/bias (applied across GPT-2)."""
        if hasattr(module, 'weight') and module.weight is not None:
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if hasattr(module, 'bias') and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, x):
        # Keep only the first feature channel: [B, T, N, 1]
        x = x[..., :1]
        x_enc = rearrange(x, 'b t n c -> b n c t')  # [B, N, 1, T]
        # Graph encoding
        H_t = self.graph_encoder(x_enc)             # [B, N, hidden_dim*(K+1)]
        X_t_1 = self.graph_projection(H_t)          # [B, N, d_model]
        enc_out = torch.cat([H_t, X_t_1], dim=-1)   # [B, N, d_model + hidden_dim*(K+1)]
        # Select a subset of word embeddings to reprogram against.
        # NOTE(review): the original also evaluated
        # `self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)`
        # here and discarded the result — a dead no-op call, now removed.
        masks = self.word_choice(self.mapping_layer.weight.data.permute(1, 0))  # 0/1 mask over the vocab
        source_embeddings = self.word_embeddings[masks == 1]  # [selected_words, d_llm]
        # Reprogramming, frozen-backbone encoding, and prediction head
        enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
        enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state  # [B, N, d_llm]
        dec_out = self.out_mlp(enc_out)             # [B, N, pred_len]
        # Reshape to the framework's expected output layout
        outputs = dec_out.unsqueeze(dim=-1)         # [B, N, pred_len, 1]
        outputs = outputs.permute(0, 2, 1, 3)       # [B, pred_len, N, 1]
        return outputs

View File

@@ -25,6 +25,7 @@ from model.STAWnet.STAWnet import STAWnet
from model.REPST.repst import repst as REPST from model.REPST.repst import repst as REPST
from model.AEPSA.aepsa import AEPSA as AEPSA from model.AEPSA.aepsa import AEPSA as AEPSA
from model.AEPSA.aepsav2 import AEPSA as AEPSAv2 from model.AEPSA.aepsav2 import AEPSA as AEPSAv2
from model.AEPSA.aepsav3 import AEPSA as AEPSAv3
@@ -86,3 +87,5 @@ def model_selector(config):
return AEPSA(model_config) return AEPSA(model_config)
case "AEPSA_v2": case "AEPSA_v2":
return AEPSAv2(model_config) return AEPSAv2(model_config)
case "AEPSA_v3":
return AEPSAv3(model_config)