import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

from model.AEPSA.normalizer import GumbelSoftmax
from model.AEPSA.reprogramming import PatchEmbedding, ReprogrammingLayer


class DynamicGraphEnhancer(nn.Module):
    """Build a per-sample graph from learnable node embeddings and input features.

    The learnable node embeddings are modulated element-wise by a small MLP
    applied to the input features; the adjacency is the ReLU of the Gram
    matrix of the modulated embeddings, followed by symmetric degree
    normalization.
    """

    def __init__(self, num_nodes, in_dim, embed_dim=10):
        super().__init__()
        self.num_nodes = num_nodes
        self.embed_dim = embed_dim

        # Learnable static embedding, one row per node.
        self.node_embeddings = nn.Parameter(
            torch.randn(num_nodes, embed_dim), requires_grad=True
        )

        # Maps node features to a per-node factor that modulates the embedding.
        self.feature_transform = nn.Sequential(
            nn.Linear(in_dim, 16),
            nn.Sigmoid(),
            nn.Linear(16, 2),
            nn.Sigmoid(),
            nn.Linear(2, embed_dim),
        )

        # Identity matrix kept as a buffer so it follows the module across devices.
        self.register_buffer("eye", torch.eye(num_nodes))

    def get_laplacian(self, graph, I, normalize=True):
        """Symmetrically normalize ``graph`` by its row sums.

        Args:
            graph: adjacency matrix ``[..., N, N]``.
            I: identity matrix ``[N, N]``; only used when ``normalize`` is False.
            normalize: if True return ``D^-1/2 @ graph @ D^-1/2``; otherwise
                return ``D^-1/2 @ (graph + I) @ D^-1/2``. In both cases ``D``
                is the row-sum degree of ``graph`` itself.

        Returns:
            Tensor with the same shape as ``graph``.
        """
        degree = torch.sum(graph, -1)

        # A zero degree would give inf for degree ** -0.5. Substituting 1 before
        # the power and zeroing afterwards yields the same forward value (0) as
        # overwriting the inf, but keeps the backward pass free of the
        # 0 * inf = NaN that pow's gradient produces at 0.
        isolated = degree == 0
        safe_degree = degree.masked_fill(isolated, 1.0)
        d_inv_sqrt = safe_degree.pow(-0.5).masked_fill(isolated, 0.0)
        D_inv = torch.diag_embed(d_inv_sqrt)

        if normalize:
            return torch.matmul(torch.matmul(D_inv, graph), D_inv)
        return torch.matmul(torch.matmul(D_inv, graph + I), D_inv)

    def forward(self, X):
        """Generate one normalized graph matrix per sample.

        Args:
            X: input features ``[B, N, D]``.

        Returns:
            Tensor ``[B, N, N]``.
        """
        I = self.eye.to(X.device)

        # Every step broadcasts over the batch dimension, so no Python loop
        # over samples is needed.
        filt = self.feature_transform(X)  # [B, N, embed_dim]
        nodevec = torch.tanh(self.node_embeddings * filt)  # [B, N, embed_dim]
        adj = F.relu(torch.matmul(nodevec, nodevec.transpose(-1, -2)))  # [B, N, N]
        return self.get_laplacian(adj, I)


class GraphEnhancedEncoder(nn.Module):
    """Chebyshev-polynomial graph encoder driven by a dynamically generated graph.

    For each order ``k`` in ``0..K`` the Chebyshev term ``T_k(L_tilde) X`` is
    projected by its own weight matrix and the results are concatenated along
    the feature axis.
    """

    def __init__(self, K=3, in_dim=64, hidden_dim=32, num_nodes=325, embed_dim=10, device='cpu'):
        super().__init__()
        self.K = K  # Chebyshev polynomial order
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.device = device

        # Produces the per-sample graph matrix from the input features.
        self.graph_enhancer = DynamicGraphEnhancer(num_nodes, in_dim, embed_dim)

        # Spectral coefficients alpha_k.
        # NOTE(review): not referenced by forward(); kept so existing
        # checkpoints (state_dict keys) still load.
        self.alpha = nn.Parameter(torch.randn(K + 1, 1))

        # One propagation weight W_k per Chebyshev order.
        self.W = nn.ParameterList([
            nn.Parameter(torch.randn(in_dim, hidden_dim)) for _ in range(K + 1)
        ])

        self.to(device)

    def chebyshev_polynomials(self, L_tilde, X):
        """Return ``[T_0(L_tilde) X, ..., T_K(L_tilde) X]`` via the three-term recurrence.

        ``T_0 X = X``, ``T_1 X = L_tilde X`` and
        ``T_k X = 2 L_tilde T_{k-1} X - T_{k-2} X``. Works for a single sample
        or a batch, since ``torch.matmul`` broadcasts leading dimensions.
        """
        T_k_list = [X]
        if self.K >= 1:
            T_k_list.append(torch.matmul(L_tilde, X))
        for _ in range(2, self.K + 1):
            T_k_list.append(2 * torch.matmul(L_tilde, T_k_list[-1]) - T_k_list[-2])
        return T_k_list

    def _rescale_single(self, L, identity):
        """Rescale one matrix to ``(2 / lambda_max) * L - I``; identity if eigvalsh fails."""
        try:
            lambda_max = torch.linalg.eigvalsh(L).max()
        except RuntimeError:
            # torch.linalg failures are RuntimeError subclasses; fall back to
            # the identity for this sample only.
            return identity
        if lambda_max < 1e-6:
            # Avoid dividing by a (near-)zero or negative largest eigenvalue.
            lambda_max = 1.0
        return (2.0 / lambda_max) * L - identity

    def _rescale_laplacian(self, laplacians):
        """Rescale a batch ``[B, N, N]`` to ``(2 / lambda_max) * L - I`` per sample."""
        identity = torch.eye(
            laplacians.size(-1), device=laplacians.device, dtype=laplacians.dtype
        )
        try:
            lambda_max = torch.linalg.eigvalsh(laplacians).max(dim=-1).values
        except RuntimeError:
            # The batched decomposition failed; retry sample by sample so that
            # only the offending samples fall back to the identity.
            return torch.stack(
                [self._rescale_single(L, identity) for L in laplacians], dim=0
            )

        # Avoid dividing by a (near-)zero or negative largest eigenvalue.
        lambda_max = torch.where(
            lambda_max < 1e-6, torch.ones_like(lambda_max), lambda_max
        )
        return (2.0 / lambda_max)[..., None, None] * laplacians - identity

    def forward(self, X):
        """Encode features with the dynamically generated graph.

        Args:
            X: input features ``[B, N, D]``.

        Returns:
            Tensor ``[B, N, hidden_dim * (K + 1)]``.
        """
        laplacians = self.graph_enhancer(X)  # [B, N, N]
        L_tilde = self._rescale_laplacian(laplacians)

        # Chebyshev expansion for the whole batch at once.
        T_k_list = self.chebyshev_polynomials(L_tilde, X)

        # Project each order with its own weight and concatenate all scales.
        H_list = [torch.matmul(T_k, W_k) for T_k, W_k in zip(T_k_list, self.W)]
        return torch.cat(H_list, dim=-1)


class AEPSA(nn.Module):
    """Patch-embedded inputs, enhanced by a dynamic graph encoder, fed through GPT-2.

    Patch embeddings are concatenated with graph-enhanced features, fused back
    to ``d_model``, reprogrammed against a selected subset of GPT-2 word
    embeddings, passed through the truncated GPT-2 stack, and projected to
    ``pred_len`` outputs by an MLP head.
    """

    def __init__(self, configs):
        super().__init__()
        self.device = configs['device']
        self.pred_len = configs['pred_len']
        self.seq_len = configs['seq_len']
        self.patch_len = configs['patch_len']
        self.input_dim = configs['input_dim']
        self.stride = configs['stride']
        self.dropout = configs['dropout']
        self.gpt_layers = configs['gpt_layers']
        self.d_ff = configs['d_ff']
        self.gpt_path = configs['gpt_path']
        self.num_nodes = configs.get('num_nodes', 325)  # number of graph nodes

        self.word_choice = GumbelSoftmax(configs['word_num'])

        self.d_model = configs['d_model']
        self.n_heads = configs['n_heads']
        self.d_keys = None
        self.d_llm = 768

        self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2)
        self.head_nf = self.d_ff * self.patch_nums

        # Patch embedding of the raw series.
        self.patch_embedding = PatchEmbedding(
            self.d_model, self.patch_len, self.stride, self.dropout,
            self.patch_nums, self.input_dim,
        )

        # GPT-2 backbone, truncated to the first `gpt_layers` blocks.
        self.gpts = GPT2Model.from_pretrained(
            self.gpt_path, output_attentions=True, output_hidden_states=True
        )
        self.gpts.h = self.gpts.h[:self.gpt_layers]
        # NOTE(review): apply() visits every submodule, so this re-initializes
        # every weight/bias that was just loaded by from_pretrained. Confirm
        # that discarding the pretrained values is intentional.
        self.gpts.apply(self.reset_parameters)

        self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)
        self.vocab_size = self.word_embeddings.shape[0]
        self.mapping_layer = nn.Linear(self.vocab_size, 1)
        self.reprogramming_layer = ReprogrammingLayer(
            self.d_model, self.n_heads, self.d_keys, self.d_llm
        )

        # Dynamic graph-enhanced encoder.
        chebyshev_order = configs.get('chebyshev_order', 3)
        graph_hidden_dim = configs.get('graph_hidden_dim', 32)
        self.graph_encoder = GraphEnhancedEncoder(
            K=chebyshev_order,
            in_dim=self.d_model,
            hidden_dim=graph_hidden_dim,
            num_nodes=self.num_nodes,
            embed_dim=configs.get('graph_embed_dim', 10),
            device=self.device,
        )

        # Fuses [patch features | graph features] back down to d_model.
        self.feature_fusion = nn.Linear(
            self.d_model + graph_hidden_dim * (chebyshev_order + 1),
            self.d_model,
        )

        self.out_mlp = nn.Sequential(
            nn.Linear(self.d_llm, 128),
            nn.ReLU(),
            nn.Linear(128, self.pred_len),
        )

        # Freeze the backbone except parameters whose name contains 'wpe'.
        for name, param in self.gpts.named_parameters():
            param.requires_grad = 'wpe' in name

    def reset_parameters(self, module):
        """Re-initialize ``module``: weight ~ N(0, 0.02), bias = 0 (when present)."""
        if hasattr(module, 'weight') and module.weight is not None:
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if hasattr(module, 'bias') and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, x):
        """Run the model; the graph structure is generated internally.

        Args:
            x: input data ``[B, T, N, C]``; only the first channel is used.

        Returns:
            Output tensor; presumably ``[B, pred_len, N, n_vars]`` given the
            final unsqueeze/repeat/permute, assuming the backbone yields
            ``[B, N, d_llm]`` (TODO confirm against PatchEmbedding and
            ReprogrammingLayer).
        """
        x = x[..., :1]
        x_enc = rearrange(x, 'b t n c -> b n c t')

        # Patch embedding.
        enc_out, n_vars = self.patch_embedding(x_enc)

        # Graph-enhanced features from the automatically generated graph.
        graph_enhanced = self.graph_encoder(enc_out)

        # Feature fusion: concatenate along the feature axis, project to d_model.
        enc_out = torch.cat([enc_out, graph_enhanced], dim=-1)
        enc_out = self.feature_fusion(enc_out)

        # Select the word embeddings used as the reprogramming vocabulary.
        # NOTE(review): the weight is read through `.data`, so no gradient
        # reaches mapping_layer from this path; confirm that is intended.
        masks = self.word_choice(self.mapping_layer.weight.data.permute(1, 0))
        source_embeddings = self.word_embeddings[masks == 1]

        enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
        enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state

        dec_out = self.out_mlp(enc_out)
        outputs = dec_out.unsqueeze(dim=-1)
        outputs = outputs.repeat(1, 1, 1, n_vars)
        outputs = outputs.permute(0, 2, 1, 3)
        return outputs