# TrafficWheel/model/AEPSA/aepsa.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

from model.AEPSA.normalizer import GumbelSoftmax
from model.AEPSA.reprogramming import PatchEmbedding, ReprogrammingLayer

class DynamicGraphEnhancer(nn.Module):
    """
    Dynamic graph enhancer that generates the graph structure automatically
    from node embeddings. Following the design of DDGCRN, the adjacency
    matrix is computed dynamically from node embeddings and input features.
    """
    def __init__(self, num_nodes, in_dim, embed_dim=10):
        super().__init__()
        self.num_nodes = num_nodes
        self.embed_dim = embed_dim
        # Learnable node embeddings
        self.node_embeddings = nn.Parameter(
            torch.randn(num_nodes, embed_dim), requires_grad=True
        )
        # Feature transform that produces a dynamic adjustment factor
        # for the node embeddings
        self.feature_transform = nn.Sequential(
            nn.Linear(in_dim, 16),
            nn.Sigmoid(),
            nn.Linear(16, 2),
            nn.Sigmoid(),
            nn.Linear(2, embed_dim)
        )
        # Register the identity matrix as a fixed support
        self.register_buffer("eye", torch.eye(num_nodes))

    def get_laplacian(self, graph, I, normalize=True):
        """
        Compute the symmetrically normalized graph matrix
        D^{-1/2} A D^{-1/2}, used here in place of the Laplacian.
        """
        # Inverse square root of the degree matrix
        d_inv_sqrt = torch.sum(graph, -1) ** (-0.5)
        d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0  # guard against isolated nodes (zero degree)
        D_inv = torch.diag_embed(d_inv_sqrt)
        if normalize:
            return torch.matmul(torch.matmul(D_inv, graph), D_inv)
        else:
            return torch.matmul(torch.matmul(D_inv, graph + I), D_inv)
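
    # A worked instance of get_laplacian (hypothetical 2-node graph): with
    # A = [[0, 1], [1, 0]], the degrees are [1, 1], so D^{-1/2} = I and
    # D^{-1/2} A D^{-1/2} = A, whose eigenvalues lie in [-1, 1].
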
    def forward(self, X):
        """
        X: input features [B, N, D]
        Returns: dynamically generated normalized Laplacians [B, N, N]
        """
        batch_size = X.size(0)
        laplacians = []
        # Identity support matrix
        I = self.eye.to(X.device)
        for b in range(batch_size):
            # Dynamic embedding adjustment factor from the feature transform
            filt = self.feature_transform(X[b])  # [N, embed_dim]
            # Feature-conditioned node embedding vectors
            nodevec = torch.tanh(self.node_embeddings * filt)
            # Adjacency matrix from the inner product of node embeddings
            adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1)))
            # Normalized Laplacian for this sample
            laplacian = self.get_laplacian(adj, I)
            laplacians.append(laplacian)
        return torch.stack(laplacians, dim=0)
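
# A minimal shape sketch for DynamicGraphEnhancer (hypothetical sizes, not
# taken from the original configuration):
#
#     enhancer = DynamicGraphEnhancer(num_nodes=4, in_dim=8)
#     X = torch.randn(2, 4, 8)   # [B, N, D]
#     L = enhancer(X)            # [B, N, N] normalized graph matrices
#     assert L.shape == (2, 4, 4)
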

class GraphEnhancedEncoder(nn.Module):
    """
    Graph-enhanced encoder based on Chebyshev polynomials and dynamic
    Laplacians, used to fold the dynamic graph structure into the
    feature encoding.
    """
    def __init__(self, K=3, in_dim=64, hidden_dim=32, num_nodes=325, embed_dim=10, device='cpu'):
        super().__init__()
        self.K = K  # Chebyshev polynomial order
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.device = device
        # Dynamic graph enhancer
        self.graph_enhancer = DynamicGraphEnhancer(num_nodes, in_dim, embed_dim)
        # Spectral coefficients alpha_k (learnable; currently unused in forward)
        self.alpha = nn.Parameter(torch.randn(K + 1, 1))
        # Propagation weights W_k (learnable)
        self.W = nn.ParameterList([
            nn.Parameter(torch.randn(in_dim, hidden_dim)) for _ in range(K + 1)
        ])
        self.to(device)

    def chebyshev_polynomials(self, L_tilde, X):
        """Recursively compute [T_0(L_tilde)X, ..., T_K(L_tilde)X]
        via the recurrence T_k(x) = 2x * T_{k-1}(x) - T_{k-2}(x)."""
        T_k_list = [X]
        if self.K >= 1:
            T_k_list.append(torch.matmul(L_tilde, X))
        for k in range(2, self.K + 1):
            T_k_list.append(2 * torch.matmul(L_tilde, T_k_list[-1]) - T_k_list[-2])
        return T_k_list
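
    # A scalar instance of the recurrence in chebyshev_polynomials: with
    # x = 0.5, T_0 = 1, T_1 = 0.5, T_2 = 2 * 0.5 * 0.5 - 1 = -0.5, which
    # matches the closed form T_2(x) = 2x^2 - 1 evaluated at x = 0.5.
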
    def forward(self, X):
        """
        X: input features [B, N, D]
        Returns: enhanced features [B, N, hidden_dim*(K+1)]
        """
        batch_size, num_nodes, _ = X.shape
        enhanced_features = []
        # Dynamically generate the Laplacians
        laplacians = self.graph_enhancer(X)
        for b in range(batch_size):
            L = laplacians[b]
            # Rescale so the spectrum fits Chebyshev's [-1, 1] domain
            try:
                lambda_max = torch.linalg.eigvalsh(L).max()
                # Avoid division by zero
                if lambda_max < 1e-6:
                    lambda_max = 1.0
                L_tilde = (2.0 / lambda_max) * L - torch.eye(L.size(0), device=L.device)
            except RuntimeError:
                # Fall back to the identity if the eigendecomposition fails
                L_tilde = torch.eye(num_nodes, device=X.device)
            # Chebyshev polynomial expansion
            T_k_list = self.chebyshev_polynomials(L_tilde, X[b])
            H_list = []
            # Apply the propagation weights
            for k in range(self.K + 1):
                H_k = torch.matmul(T_k_list[k], self.W[k])
                H_list.append(H_k)
            # Concatenate the features from all scales
            X_enhanced = torch.cat(H_list, dim=-1)  # [N, hidden_dim*(K+1)]
            enhanced_features.append(X_enhanced)
        return torch.stack(enhanced_features, dim=0)
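
# A minimal shape sketch for GraphEnhancedEncoder (hypothetical sizes): with
# K=3 and hidden_dim=32, the output width is 32 * (3 + 1) = 128.
#
#     encoder = GraphEnhancedEncoder(K=3, in_dim=64, hidden_dim=32, num_nodes=4)
#     out = encoder(torch.randn(2, 4, 64))  # [B, N, hidden_dim*(K+1)]
#     assert out.shape == (2, 4, 128)
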

class AEPSA(nn.Module):
    def __init__(self, configs):
        super(AEPSA, self).__init__()
        self.device = configs['device']
        self.pred_len = configs['pred_len']
        self.seq_len = configs['seq_len']
        self.patch_len = configs['patch_len']
        self.input_dim = configs['input_dim']
        self.stride = configs['stride']
        self.dropout = configs['dropout']
        self.gpt_layers = configs['gpt_layers']
        self.d_ff = configs['d_ff']
        self.gpt_path = configs['gpt_path']
        self.num_nodes = configs.get('num_nodes', 325)  # number of graph nodes
        self.output_dim = configs.get('output_dim', 1)
        self.word_choice = GumbelSoftmax(configs['word_num'])
        self.d_model = configs['d_model']
        self.n_heads = configs['n_heads']
        self.d_keys = None
        self.d_llm = 768
        self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2)
        self.head_nf = self.d_ff * self.patch_nums
        # Patch embedding
        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim, self.output_dim)
        # GPT-2 backbone, truncated to the first gpt_layers blocks
        self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
        self.gpts.h = self.gpts.h[:self.gpt_layers]
        # Re-initialize the backbone weights (this overwrites the pretrained values)
        self.gpts.apply(self.reset_parameters)
        self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)
        self.vocab_size = self.word_embeddings.shape[0]
        self.mapping_layer = nn.Linear(self.vocab_size, 1)
        self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm)
        # Dynamic graph-enhanced encoder
        self.graph_encoder = GraphEnhancedEncoder(
            K=configs.get('chebyshev_order', 3),
            in_dim=self.d_model,
            hidden_dim=configs.get('graph_hidden_dim', 32),
            num_nodes=self.num_nodes,
            embed_dim=configs.get('graph_embed_dim', 10),
            device=self.device
        )
        # Feature fusion layer
        self.feature_fusion = nn.Linear(
            self.d_model + configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1),
            self.d_model
        )
        self.out_mlp = nn.Sequential(
            nn.Linear(self.d_llm, 128),
            nn.ReLU(),
            nn.Linear(128, self.pred_len)
        )
        # Freeze GPT-2 except for the positional embeddings (wpe)
        for name, param in self.gpts.named_parameters():
            if 'wpe' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False

    def reset_parameters(self, module):
        if hasattr(module, 'weight') and module.weight is not None:
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if hasattr(module, 'bias') and module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    def forward(self, x):
        """
        x: input data [B, T, N, C]
        The graph structure is generated automatically; no external
        adjacency matrix is required.
        """
        x = x[..., :self.input_dim]
        x_enc = rearrange(x, 'b t n c -> b n c t')
        # Patch embedding, as in the original design
        enc_out, n_vars = self.patch_embedding(x_enc)  # [B, N, d_model]
        # Graph-enhanced encoder (graph structure generated automatically)
        graph_enhanced = self.graph_encoder(enc_out)
        # Feature fusion: both tensors are 3-D, fused back to [B, N, d_model]
        enc_out = torch.cat([enc_out, graph_enhanced], dim=-1)
        enc_out = self.feature_fusion(enc_out)
        # Select a subset of the GPT-2 vocabulary via Gumbel-Softmax masking
        masks = self.word_choice(self.mapping_layer.weight.data.permute(1, 0))
        source_embeddings = self.word_embeddings[masks == 1]
        # Reprogram the fused features into the LLM embedding space
        enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
        enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state
        dec_out = self.out_mlp(enc_out)              # [B, N, pred_len]
        outputs = dec_out.unsqueeze(dim=-1)          # [B, N, pred_len, 1]
        outputs = outputs.repeat(1, 1, 1, n_vars)    # [B, N, pred_len, n_vars]
        outputs = outputs.permute(0, 2, 1, 3)        # [B, pred_len, N, n_vars]
        return outputs
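

# A minimal, hedged usage sketch. The configuration keys come from the
# constructor above; the values below are illustrative assumptions, and
# 'gpt_path' must point to a local or Hugging Face GPT-2 checkpoint.
#
#     configs = {
#         'device': 'cpu', 'pred_len': 12, 'seq_len': 12, 'patch_len': 6,
#         'input_dim': 1, 'stride': 3, 'dropout': 0.1, 'gpt_layers': 3,
#         'd_ff': 128, 'gpt_path': 'gpt2', 'num_nodes': 325, 'output_dim': 1,
#         'word_num': 1000, 'd_model': 64, 'n_heads': 8,
#     }
#     model = AEPSA(configs)
#     x = torch.randn(2, 12, 325, 1)  # [B, T, N, C]
#     y = model(x)                    # [B, pred_len, N, n_vars]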