TrafficWheel/model/AEPSA/aepsav3.py

import torch
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from einops import rearrange
from model.AEPSA.normalizer import GumbelSoftmax
from model.AEPSA.reprogramming import ReprogrammingLayer
import torch.nn.functional as F


# Spatio-temporal forecasting model based on dynamic graph enhancement
class DynamicGraphEnhancer(nn.Module):
    """Dynamic graph enhancement encoder."""
    def __init__(self, num_nodes, in_dim, embed_dim=10):
        super().__init__()
        self.num_nodes = num_nodes  # number of nodes
        self.embed_dim = embed_dim  # node embedding dimension
        self.node_embeddings = nn.Parameter(torch.randn(num_nodes, embed_dim), requires_grad=True)  # learnable node embeddings
        self.feature_transform = nn.Sequential(  # feature transformation network
            nn.Linear(in_dim, 16),
            nn.Sigmoid(),
            nn.Linear(16, 2),
            nn.Sigmoid(),
            nn.Linear(2, embed_dim)
        )
        self.register_buffer("eye", torch.eye(num_nodes))  # register the identity matrix

    def get_laplacian(self, graph, I, normalize=True):
        D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))  # inverse square root of the degree matrix
        D_inv[torch.isinf(D_inv)] = 0.0  # guard against division by zero for isolated nodes
        if normalize:
            return torch.matmul(torch.matmul(D_inv, graph), D_inv)  # normalized graph D^{-1/2} A D^{-1/2}
        else:
            return torch.matmul(torch.matmul(D_inv, graph + I), D_inv)  # normalized graph with self-loops

    def forward(self, X):
        """Generate a dynamic Laplacian matrix for each batch element."""
        batch_size = X.size(0)  # batch size
        laplacians = []  # Laplacian matrices for each batch element
        I = self.eye.to(X.device)  # move the identity matrix to the target device
        for b in range(batch_size):
            filt = self.feature_transform(X[b])  # feature transformation, [N, embed_dim]
            nodevec = torch.tanh(self.node_embeddings * filt)  # data-conditioned node embeddings
            adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1)))  # adjacency matrix
            laplacian = self.get_laplacian(adj, I)  # normalized Laplacian
            laplacians.append(laplacian)
        return torch.stack(laplacians, dim=0)  # stack and return, [B, N, N]


class GraphEnhancedEncoder(nn.Module):
    """Graph-enhanced encoder."""
    def __init__(self, K=3, in_dim=64, hidden_dim=32, num_nodes=325, embed_dim=10, device='cpu',
                 temporal_dim=12, num_features=1):
        super().__init__()
        self.K = K  # Chebyshev polynomial order
        self.in_dim = in_dim  # input feature dimension
        self.hidden_dim = hidden_dim  # hidden dimension
        self.device = device  # execution device
        self.temporal_dim = temporal_dim  # length of the input time series
        self.num_features = num_features  # number of feature channels
        self.input_projection = nn.Sequential(  # input projection layer
            nn.Conv2d(num_features, 16, kernel_size=(1, 3), padding=(0, 1)),
            nn.ReLU(),
            nn.Conv2d(16, in_dim, kernel_size=(1, temporal_dim)),
            nn.ReLU()
        )
        self.graph_enhancer = DynamicGraphEnhancer(num_nodes, in_dim, embed_dim)  # dynamic graph enhancer
        self.alpha = nn.Parameter(torch.randn(K + 1, 1))  # spectral coefficients
        self.W = nn.ParameterList([nn.Parameter(torch.randn(in_dim, hidden_dim)) for _ in range(K + 1)])  # propagation weights
        self.to(device)  # move to the target device

    def chebyshev_polynomials(self, L_tilde, X):
        """Compute the Chebyshev polynomial expansion T_k(L_tilde) X."""
        T_k_list = [X]  # T_0(L) X = X
        if self.K >= 1:
            T_k_list.append(torch.matmul(L_tilde, X))  # T_1(L) X = L_tilde X
        for k in range(2, self.K + 1):
            T_k_list.append(2 * torch.matmul(L_tilde, T_k_list[-1]) - T_k_list[-2])  # recurrence T_k = 2 L T_{k-1} - T_{k-2}
        return T_k_list  # list of K+1 polynomial terms

    def forward(self, X):
        """Input features [B, N, C, T]; returns enhanced features [B, N, hidden_dim*(K+1)]."""
        batch_size = X.size(0)  # batch size
        num_nodes = X.size(1)  # number of nodes
        x = X.permute(0, 2, 1, 3)  # [B, C, N, T]
        x_proj = self.input_projection(x).squeeze(-1)  # [B, in_dim, N]
        x_proj = x_proj.permute(0, 2, 1)  # [B, N, in_dim]
        enhanced_features = []  # enhanced features per batch element
        laplacians = self.graph_enhancer(x_proj)  # dynamic Laplacian matrices
        for b in range(batch_size):
            L = laplacians[b]  # Laplacian for the current batch element
            # Eigenvalue rescaling
            try:
                lambda_max = torch.linalg.eigvalsh(L).max()  # largest eigenvalue (eigvalsh returns real values)
                lambda_max = 1.0 if lambda_max < 1e-6 else lambda_max  # avoid division by zero
                L_tilde = (2.0 / lambda_max) * L - torch.eye(L.size(0), device=L.device)  # rescaled Laplacian
            except RuntimeError:
                L_tilde = torch.eye(num_nodes, device=X.device)  # fall back to the identity if the eigendecomposition fails
            # Compute the expansion and apply the propagation weights
            T_k_list = self.chebyshev_polynomials(L_tilde, x_proj[b])  # Chebyshev polynomial terms
            H_list = [torch.matmul(T_k_list[k], self.W[k]) for k in range(self.K + 1)]  # apply weights per order
            X_enhanced = torch.cat(H_list, dim=-1)  # concatenate features across orders
            enhanced_features.append(X_enhanced)
        return torch.stack(enhanced_features, dim=0)  # [B, N, hidden_dim*(K+1)]: Chebyshev features of every node for each order k


class AEPSA(nn.Module):
    """Adaptive feature-projection spatio-temporal self-attention model."""
    def __init__(self, configs):
        super(AEPSA, self).__init__()
        self.device = configs['device']  # execution device
        self.pred_len = configs['pred_len']  # prediction horizon
        self.seq_len = configs['seq_len']  # input sequence length
        self.patch_len = configs['patch_len']  # patch length
        self.input_dim = configs['input_dim']  # input feature dimension
        self.stride = configs['stride']  # patch stride
        self.dropout = configs['dropout']  # dropout probability
        self.gpt_layers = configs['gpt_layers']  # number of GPT-2 layers to keep
        self.d_ff = configs['d_ff']  # feed-forward hidden dimension
        self.gpt_path = configs['gpt_path']  # path to the pretrained GPT-2 model
        self.num_nodes = configs.get('num_nodes', 325)  # number of nodes
        self.word_choice = GumbelSoftmax(configs['word_num'])  # vocabulary selection layer
        self.d_model = configs['d_model']  # model dimension
        self.n_heads = configs['n_heads']  # number of attention heads
        self.d_keys = None  # key dimension (left to the reprogramming layer's default)
        self.d_llm = 768  # GPT-2 hidden dimension
        self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2)  # number of patches
        self.head_nf = self.d_ff * self.patch_nums  # head feature dimension
        # Initialize the GPT-2 backbone
        self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)  # GPT-2 model
        self.gpts.h = self.gpts.h[:self.gpt_layers]  # keep only the first gpt_layers blocks
        self.gpts.apply(self.reset_parameters)  # re-initialize parameters
        self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)  # word embedding weights
        self.vocab_size = self.word_embeddings.shape[0]  # vocabulary size
        self.mapping_layer = nn.Linear(self.vocab_size, 1)  # vocabulary mapping layer
        self.reprogramming_layer = ReprogrammingLayer(
            self.d_model + configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1),
            self.n_heads, self.d_keys, self.d_llm
        )  # reprogramming layer
        # Initialize the graph-enhanced encoder
        self.graph_encoder = GraphEnhancedEncoder(
            K=configs.get('chebyshev_order', 3),  # Chebyshev polynomial order
            in_dim=self.d_model,  # input feature dimension
            hidden_dim=configs.get('graph_hidden_dim', 32),  # hidden dimension
            num_nodes=self.num_nodes,  # number of nodes
            embed_dim=configs.get('graph_embed_dim', 10),  # node embedding dimension
            device=self.device,  # execution device
            temporal_dim=self.seq_len,  # length of the input time series
            num_features=self.input_dim  # number of feature channels
        )
        self.graph_projection = nn.Linear(  # projects the concatenated Chebyshev features of all K+1 orders to the model dimension
            configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1),  # input dimension
            self.d_model  # output dimension
        )
        self.out_mlp = nn.Sequential(
            nn.Linear(self.d_llm, 128),
            nn.ReLU(),
            nn.Linear(128, self.pred_len)
        )
        # Train only the word position embeddings (wpe) of GPT-2; freeze everything else
        for name, param in self.gpts.named_parameters():
            param.requires_grad = 'wpe' in name

    def reset_parameters(self, module):
        if hasattr(module, 'weight') and module.weight is not None:
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if hasattr(module, 'bias') and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, x):
        # Data preparation
        x = x[..., :1]  # [B, T, N, 1]
        x_enc = rearrange(x, 'b t n c -> b n c t')  # [B, N, 1, T]
        # Graph encoding
        H_t = self.graph_encoder(x_enc)  # [B, N, 1, T] -> [B, N, hidden_dim*(K+1)]
        X_t_1 = self.graph_projection(H_t)  # [B, N, d_model]
        X_enc = torch.cat([H_t, X_t_1], dim=-1)  # [B, N, d_model + hidden_dim*(K+1)]
        # Word-embedding selection
        self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)  # note: result is unused; selection below relies on the mapping weights
        masks = self.word_choice(self.mapping_layer.weight.data.permute(1, 0))  # [vocab_size, 1]
        source_embeddings = self.word_embeddings[masks == 1]  # [selected_words, d_llm]
        # Reprogramming and prediction
        X_enc = self.reprogramming_layer(X_enc, source_embeddings, source_embeddings)
        X_enc = self.gpts(inputs_embeds=X_enc).last_hidden_state  # [B, N, d_llm]
        dec_out = self.out_mlp(X_enc)  # [B, N, pred_len]
        # Dimension rearrangement
        outputs = dec_out.unsqueeze(dim=-1)  # [B, N, pred_len, 1]
        outputs = outputs.permute(0, 2, 1, 3)  # [B, pred_len, N, 1]
        return outputs
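

# --- Hedged usage sketch (not part of the original module) ---
# A minimal smoke test for the graph-enhanced encoder alone, assuming a small
# synthetic input of shape [B, N, C, T]. The sizes below (2 samples, 4 nodes,
# 12 time steps, in_dim=64) are illustrative placeholders, not project defaults;
# the full AEPSA model is not exercised here because it needs a local GPT-2
# checkpoint and the GumbelSoftmax/ReprogrammingLayer modules.
if __name__ == "__main__":
    B, N, C, T = 2, 4, 1, 12  # hypothetical batch, nodes, channels, time steps
    encoder = GraphEnhancedEncoder(K=3, in_dim=64, hidden_dim=32, num_nodes=N,
                                   embed_dim=10, device='cpu', temporal_dim=T,
                                   num_features=C)
    dummy = torch.randn(B, N, C, T)  # synthetic traffic-like input
    out = encoder(dummy)
    # Expected output shape: [B, N, hidden_dim * (K + 1)] = [2, 4, 128]
    print(out.shape)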