# TrafficWheel/model/ASTRA/astra.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

from model.ASTRA.normalizer import GumbelSoftmax
from model.ASTRA.reprogramming import PatchEmbedding, ReprogrammingLayer

class DynamicGraphEnhancer(nn.Module):
    """Dynamic graph enhancement encoder: learns a graph from node features."""

    def __init__(self, num_nodes, in_dim, embed_dim=10):
        super().__init__()
        self.num_nodes = num_nodes  # number of nodes
        self.embed_dim = embed_dim  # node embedding dimension
        # learnable node embeddings
        self.node_embeddings = nn.Parameter(torch.randn(num_nodes, embed_dim), requires_grad=True)
        # feature transformation network
        self.feature_transform = nn.Sequential(
            nn.Linear(in_dim, 16),
            nn.Sigmoid(),
            nn.Linear(16, 2),
            nn.Sigmoid(),
            nn.Linear(2, embed_dim)
        )
        self.register_buffer("eye", torch.eye(num_nodes))  # cached identity matrix

    def get_laplacian(self, graph, I, normalize=True):
        """Symmetric normalization D^{-1/2} A D^{-1/2}, used here in place of a Laplacian."""
        D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))  # inverse square root of the degree matrix
        D_inv[torch.isinf(D_inv)] = 0.0  # guard against division by zero for isolated nodes
        if normalize:
            return torch.matmul(torch.matmul(D_inv, graph), D_inv)      # normalized adjacency
        else:
            return torch.matmul(torch.matmul(D_inv, graph + I), D_inv)  # normalized adjacency with self-loops

    def forward(self, X):
        """Generate one dynamic Laplacian-like matrix per batch element."""
        batch_size = X.size(0)
        laplacians = []  # per-sample graph matrices
        I = self.eye.to(X.device)  # move the identity matrix to the input's device
        for b in range(batch_size):
            filt = self.feature_transform(X[b])                # feature transformation, [N, embed_dim]
            nodevec = torch.tanh(self.node_embeddings * filt)  # feature-modulated node embeddings
            adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1)))  # non-negative adjacency matrix
            laplacian = self.get_laplacian(adj, I)             # normalized graph matrix
            laplacians.append(laplacian)
        return torch.stack(laplacians, dim=0)  # [B, N, N]
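
# Hedged shape-check sketch for DynamicGraphEnhancer (values illustrative; shapes
# follow the code above):
#
#   enhancer = DynamicGraphEnhancer(num_nodes=4, in_dim=8, embed_dim=10)
#   X = torch.randn(2, 4, 8)   # [B, N, in_dim]
#   L = enhancer(X)            # [B, N, N] == torch.Size([2, 4, 4])
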

class GraphEnhancedEncoder(nn.Module):
    """
    Graph-enhanced encoder based on Chebyshev polynomials and dynamic Laplacian matrices.
    Integrates dynamic graph structure information into the feature encoding.
    """

    def __init__(self, K=3, in_dim=64, hidden_dim=32, num_nodes=325, embed_dim=10, device='cpu'):
        super().__init__()
        self.K = K  # Chebyshev polynomial order
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.device = device
        # dynamic graph enhancer
        self.graph_enhancer = DynamicGraphEnhancer(num_nodes, in_dim, embed_dim)
        # spectral coefficients alpha_k (learnable)
        self.alpha = nn.Parameter(torch.randn(K + 1, 1))
        # propagation weights W_k (learnable)
        self.W = nn.ParameterList([
            nn.Parameter(torch.randn(in_dim, hidden_dim)) for _ in range(K + 1)
        ])
        self.to(device)

    def chebyshev_polynomials(self, L_tilde, X):
        """Recursively compute [T_0(L_tilde)X, ..., T_K(L_tilde)X]."""
        T_k_list = [X]  # T_0(L)X = X
        if self.K >= 1:
            T_k_list.append(torch.matmul(L_tilde, X))  # T_1(L)X = LX
        for k in range(2, self.K + 1):
            # Chebyshev recurrence: T_k(L)X = 2L T_{k-1}(L)X - T_{k-2}(L)X
            T_k_list.append(2 * torch.matmul(L_tilde, T_k_list[-1]) - T_k_list[-2])
        return T_k_list
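
    # Hedged sanity check (sketch; tiny illustrative inputs) that the recurrence
    # matches the explicit Chebyshev polynomial for K = 2:
    #
    #   L = torch.eye(3) * 0.5
    #   X = torch.randn(3, 4)
    #   enc = GraphEnhancedEncoder(K=2, in_dim=4, hidden_dim=2, num_nodes=3)
    #   T0, T1, T2 = enc.chebyshev_polynomials(L, X)
    #   assert torch.allclose(T2, 2 * L @ (L @ X) - X)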

    def forward(self, X):
        """
        X: input features [B, N, D]
        Returns: enhanced features [B, N, hidden_dim * (K + 1)]
        """
        batch_size, num_nodes, _ = X.shape
        enhanced_features = []
        # dynamically generate per-sample Laplacian matrices
        laplacians = self.graph_enhancer(X)
        for b in range(batch_size):
            L = laplacians[b]
            # eigenvalue rescaling for the Chebyshev basis
            try:
                lambda_max = torch.linalg.eigvalsh(L).max()
                # avoid division by zero
                if lambda_max < 1e-6:
                    lambda_max = 1.0
                L_tilde = (2.0 / lambda_max) * L - torch.eye(L.size(0), device=L.device)
            except Exception:
                # fall back to the identity matrix if the eigendecomposition fails
                L_tilde = torch.eye(num_nodes, device=X.device)
            # Chebyshev polynomial expansion
            T_k_list = self.chebyshev_polynomials(L_tilde, X[b])
            H_list = []
            # apply the propagation weights
            for k in range(self.K + 1):
                H_k = torch.matmul(T_k_list[k], self.W[k])
                H_list.append(H_k)
            # concatenate the features from all scales
            X_enhanced = torch.cat(H_list, dim=-1)  # [N, hidden_dim * (K + 1)]
            enhanced_features.append(X_enhanced)
        return torch.stack(enhanced_features, dim=0)
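
# Hedged usage sketch for GraphEnhancedEncoder (shapes only; parameter values are
# illustrative):
#
#   enc = GraphEnhancedEncoder(K=3, in_dim=64, hidden_dim=32, num_nodes=5)
#   out = enc(torch.randn(2, 5, 64))  # [B, N, hidden_dim * (K + 1)] == [2, 5, 128]
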

class ASTRA(nn.Module):
    def __init__(self, configs):
        super(ASTRA, self).__init__()
        self.device = configs['device']
        self.pred_len = configs['pred_len']
        self.seq_len = configs['seq_len']
        self.patch_len = configs['patch_len']
        self.input_dim = configs['input_dim']
        self.stride = configs['stride']
        self.dropout = configs['dropout']
        self.gpt_layers = configs['gpt_layers']
        self.d_ff = configs['d_ff']
        self.gpt_path = configs['gpt_path']
        self.num_nodes = configs.get('num_nodes', 325)  # number of graph nodes
        self.output_dim = configs.get('output_dim', 1)
        self.word_choice = GumbelSoftmax(configs['word_num'])
        self.d_model = configs['d_model']
        self.n_heads = configs['n_heads']
        self.d_keys = None
        self.d_llm = 768
        self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2)
        self.head_nf = self.d_ff * self.patch_nums
        # patch embedding
        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout,
                                              self.patch_nums, self.input_dim, self.output_dim)
        # GPT-2 initialization
        self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
        self.gpts.h = self.gpts.h[:self.gpt_layers]
        self.gpts.apply(self.reset_parameters)
        self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)
        self.vocab_size = self.word_embeddings.shape[0]
        self.mapping_layer = nn.Linear(self.vocab_size, 1)
        self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm)
        # dynamic graph-enhanced encoder
        self.graph_encoder = GraphEnhancedEncoder(
            K=configs.get('chebyshev_order', 3),
            in_dim=self.d_model * self.input_dim,
            hidden_dim=self.d_model,
            num_nodes=self.num_nodes,
            embed_dim=configs.get('graph_embed_dim', 10),
            device=self.device
        )
        # feature fusion layer
        self.feature_fusion = nn.Linear(
            self.d_model * self.input_dim + self.d_model * (configs.get('chebyshev_order', 3) + 1),
            self.d_model
        )
        self.out_mlp = nn.Sequential(
            nn.Linear(self.d_llm, 128),
            nn.ReLU(),
            nn.Linear(128, self.pred_len * self.output_dim)
        )
        # freeze GPT-2 except for the positional embeddings (wpe)
        for name, param in self.gpts.named_parameters():
            if 'wpe' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False

    def reset_parameters(self, module):
        if hasattr(module, 'weight') and module.weight is not None:
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if hasattr(module, 'bias') and module.bias is not None:
            torch.nn.init.zeros_(module.bias)
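
    # Hedged check (sketch) that only the positional embeddings stay trainable
    # after the freeze loop in __init__:
    #
    #   trainable = [n for n, p in model.gpts.named_parameters() if p.requires_grad]
    #   assert all('wpe' in n for n in trainable)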

    def forward(self, x):
        """
        x: input data [B, T, N, C]
        The graph structure is generated automatically; no external adjacency matrix is needed.
        """
        x = x[..., :self.input_dim]
        x_enc = rearrange(x, 'b t n c -> b n c t')
        # patching, as in the original model
        enc_out, n_vars = self.patch_embedding(x_enc)  # [B, N, d_model * input_dim]
        # apply the graph-enhanced encoder (graph structure is generated automatically)
        graph_enhanced = self.graph_encoder(enc_out)   # [B, N, d_model * (K + 1)]
        # feature fusion: concatenate along the channel dimension, then project to d_model
        enc_out = torch.cat([enc_out, graph_enhanced], dim=-1)
        enc_out = self.feature_fusion(enc_out)         # [B, N, d_model]
        # select a subset of GPT-2 word embeddings as the reprogramming source
        masks = self.word_choice(self.mapping_layer.weight.data.permute(1, 0))
        source_embeddings = self.word_embeddings[masks == 1]
        enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
        enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state
        dec_out = self.out_mlp(enc_out)  # [B, N, pred_len * output_dim]
        B, N, _ = dec_out.shape
        outputs = dec_out.view(B, N, self.pred_len, self.output_dim)
        outputs = outputs.permute(0, 2, 1, 3)  # [B, T, N, C]
        return outputs
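

# Hedged smoke test: a sketch, not part of the model. The config keys mirror the
# lookups in ASTRA.__init__; the values are illustrative, and 'gpt_path' must point
# to a valid GPT-2 checkpoint for GPT2Model.from_pretrained to succeed. 'word_num'
# is assumed to be the vocabulary-subset size expected by GumbelSoftmax.
if __name__ == "__main__":
    configs = {
        'device': 'cpu',
        'pred_len': 12,
        'seq_len': 12,
        'patch_len': 6,
        'input_dim': 1,
        'stride': 3,
        'dropout': 0.1,
        'gpt_layers': 3,
        'd_ff': 128,
        'gpt_path': 'gpt2',  # assumption: a local or hub GPT-2 checkpoint
        'num_nodes': 325,
        'output_dim': 1,
        'word_num': 1000,    # assumption: subset size for GumbelSoftmax word selection
        'd_model': 64,
        'n_heads': 8,
    }
    model = ASTRA(configs)
    x = torch.randn(2, configs['seq_len'], configs['num_nodes'], configs['input_dim'])  # [B, T, N, C]
    y = model(x)
    print(y.shape)  # expected: [2, pred_len, num_nodes, output_dim]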