From f3b5bdfc2807de5acdebd6ac93ad99af71d81b90 Mon Sep 17 00:00:00 2001
From: czzhangheng
Date: Fri, 28 Nov 2025 21:32:49 +0800
Subject: [PATCH] Improve v2 readability
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 config/AEPSA/v2_SolarEnergy.yaml |   2 +-
 model/AEPSA/aepsav2.py           | 423 ++++++++-----------------------
 2 files changed, 113 insertions(+), 312 deletions(-)

diff --git a/config/AEPSA/v2_SolarEnergy.yaml b/config/AEPSA/v2_SolarEnergy.yaml
index 26f286f..6f036fa 100644
--- a/config/AEPSA/v2_SolarEnergy.yaml
+++ b/config/AEPSA/v2_SolarEnergy.yaml
@@ -1,7 +1,7 @@
 basic:
   dataset: "SolarEnergy"
   mode : "train"
-  device : "cuda:1"
+  device : "cuda:0"
   model: "AEPSA_v2"
   seed: 2023
 
diff --git a/model/AEPSA/aepsav2.py b/model/AEPSA/aepsav2.py
index ceaa35d..aac9149 100644
--- a/model/AEPSA/aepsav2.py
+++ b/model/AEPSA/aepsav2.py
@@ -6,311 +6,150 @@ from model.AEPSA.normalizer import GumbelSoftmax
 from model.AEPSA.reprogramming import ReprogrammingLayer
 import torch.nn.functional as F
 
-# This file implements a spatio-temporal forecasting model based on dynamic graph enhancement.
-# It contains three classes: DynamicGraphEnhancer, GraphEnhancedEncoder, and AEPSA (the main model).
-# Every operation is annotated with its input/output shapes to clarify the data flow.
+# Spatio-temporal forecasting model based on dynamic graph enhancement
 
 class DynamicGraphEnhancer(nn.Module):
-    """
-    Dynamic graph enhancer that builds the graph structure automatically from node embeddings.
-    Following the design of DDGCRN, the adjacency matrix is computed dynamically from node embeddings and feature information.
-    """
+    """Dynamic graph enhancer"""
     def __init__(self, num_nodes, in_dim, embed_dim=10):
-        # num_nodes: number of nodes
-        # in_dim: input feature dimension
-        # embed_dim: node embedding dimension
         super().__init__()
-        self.num_nodes = num_nodes
-        self.embed_dim = embed_dim
+        self.num_nodes = num_nodes  # number of nodes
+        self.embed_dim = embed_dim  # node embedding dimension
 
-        # node embedding parameters [num_nodes, embed_dim]
-        self.node_embeddings = nn.Parameter(
-            torch.randn(num_nodes, embed_dim), requires_grad=True
+        self.node_embeddings = nn.Parameter(torch.randn(num_nodes, embed_dim), requires_grad=True)  # node embedding parameters
+
+        self.feature_transform = nn.Sequential(  # feature transformation network
+            nn.Linear(in_dim, 16),
+            nn.Sigmoid(),
+            nn.Linear(16, 2),
+            nn.Sigmoid(),
+            nn.Linear(2, embed_dim)
         )
 
-        # feature transformation layer used to generate dynamically adjusted embeddings
-        # input: [N, in_dim] -> output: [N, embed_dim]
-        self.feature_transform = nn.Sequential(
-            nn.Linear(in_dim, 16),      # [N, in_dim] -> [N, 16]
-            nn.Sigmoid(),
-            nn.Linear(16, 2),           # [N, 16] -> [N, 2]
-            nn.Sigmoid(),
-            nn.Linear(2, embed_dim)     # [N, 2] -> [N, embed_dim]
-        )
-
-        # register the identity matrix as a fixed support matrix [num_nodes, num_nodes]
-        self.register_buffer("eye", torch.eye(num_nodes))
+        self.register_buffer("eye", torch.eye(num_nodes))  # register the identity matrix
 
     def get_laplacian(self, graph, I, normalize=True):
-        """
-        Compute the normalized Laplacian matrix.
-
-        Args:
-            graph: adjacency matrix [N, N]
-            I: identity matrix [N, N]
-            normalize: whether to use the normalized Laplacian
-
-        Returns:
-            laplacian: normalized Laplacian matrix [N, N]
-        """
-        # inverse square root of the degree matrix [N, N]
-        D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))  # [N, N]
+        D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))  # inverse square root of the degree matrix
         D_inv[torch.isinf(D_inv)] = 0.0  # guard against division by zero
-
         if normalize:
-            # normalized Laplacian: D^(-1/2) * graph * D^(-1/2) [N, N]
-            return torch.matmul(torch.matmul(D_inv, graph), D_inv)  # [N, N]
+            return torch.matmul(torch.matmul(D_inv, graph), D_inv)  # normalized Laplacian
         else:
-            # Laplacian: D^(-1/2) * (graph + I) * D^(-1/2) [N, N]
-            return torch.matmul(torch.matmul(D_inv, graph + I), D_inv)  # [N, N]
+            return torch.matmul(torch.matmul(D_inv, graph + I), D_inv)  # normalized Laplacian with self-loops
 
     def forward(self, X):
-        """
-        Args:
-            X: input features [B, N, D], where B is the batch size, N the number of nodes and D the feature dimension
-
-        Returns:
-            laplacians: dynamically generated normalized Laplacian matrices [B, N, N]
-        """
-        batch_size = X.size(0)
-        laplacians = []
-
-        # identity matrix [N, N]
-        I = self.eye.to(X.device)
+        """Generate dynamic Laplacian matrices"""
+        batch_size = X.size(0)  # batch size
+        laplacians = []  # per-batch Laplacian matrices
+        I = self.eye.to(X.device)  # move the identity matrix to the target device
 
         for b in range(batch_size):
-            # dynamic embedding adjustment factors from the feature transformation layer [N, embed_dim]
-            filt = self.feature_transform(X[b])  # [N, embed_dim]
-
-            # node embedding vectors [N, embed_dim]
-            nodevec = torch.tanh(self.node_embeddings * filt)  # [N, embed_dim]
-
-            # adjacency matrix from the dot product of node embeddings [N, N]
-            adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1)))  # [N, N]
-
-            # normalized Laplacian matrix [N, N]
-            laplacian = self.get_laplacian(adj, I)  # [N, N]
+            filt = self.feature_transform(X[b])  # feature transformation
+            nodevec = torch.tanh(self.node_embeddings * filt)  # node embeddings
+            adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1)))  # adjacency matrix
+            laplacian = self.get_laplacian(adj, I)  # Laplacian matrix
             laplacians.append(laplacian)
-
-        # stack the Laplacian matrices of all batch elements [B, N, N]
-        return torch.stack(laplacians, dim=0)  # [B, N, N]
+        return torch.stack(laplacians, dim=0)  # stack and return
 
 class GraphEnhancedEncoder(nn.Module):
-    """
-    Graph-enhanced encoder based on Chebyshev polynomials and dynamic Laplacian matrices.
-    Integrates dynamic graph structure information into the feature encoding.
-    Optimized version: supports raw time-series input directly.
-    """
+    """Graph-enhanced encoder"""
     def __init__(self, K=3, in_dim=64, hidden_dim=32, num_nodes=325, embed_dim=10, device='cpu', temporal_dim=12, num_features=1):
-        # K: Chebyshev polynomial order
-        # in_dim: input feature dimension
-        # hidden_dim: hidden layer dimension
-        # num_nodes: number of nodes
-        # embed_dim: embedding dimension
-        # temporal_dim: time-series length
-        # num_features: number of feature channels
         super().__init__()
         self.K = K  # Chebyshev polynomial order
-        self.in_dim = in_dim
-        self.hidden_dim = hidden_dim
-        self.device = device
-        self.temporal_dim = temporal_dim
-        self.num_features = num_features
+        self.in_dim = in_dim  # input feature dimension
+        self.hidden_dim = hidden_dim  # hidden layer dimension
+        self.device = device  # compute device
+        self.temporal_dim = temporal_dim  # time-series length
+        self.num_features = num_features  # number of feature channels
 
-        # input preprocessing layer that adapts the channel dimension
-        # input: [B, C, N, T] -> output: [B, in_dim, N, 1]
-        self.input_projection = nn.Sequential(
-            # 2D convolution: [B, C, N, T] -> [B, 16, N, T]
-            nn.Conv2d(num_features, 16, kernel_size=(1, 3), padding=(0, 1)),  # [B, C, N, T] -> [B, 16, N, T]
+        self.input_projection = nn.Sequential(  # input projection layer
+            nn.Conv2d(num_features, 16, kernel_size=(1, 3), padding=(0, 1)),
             nn.ReLU(),
-            # 2D convolution: [B, 16, N, T] -> [B, in_dim, N, 1], global convolution over the temporal dimension
-            nn.Conv2d(16, in_dim, kernel_size=(1, temporal_dim)),  # [B, 16, N, T] -> [B, in_dim, N, 1]
+            nn.Conv2d(16, in_dim, kernel_size=(1, temporal_dim)),
             nn.ReLU()
         )
 
-        # dynamic graph enhancer used to generate dynamic Laplacian matrices
-        # input: [B, N, in_dim] -> output: [B, N, N]
-        self.graph_enhancer = DynamicGraphEnhancer(num_nodes, in_dim, embed_dim)
-
-        # spectral coefficients α_k (learnable) [K+1, 1]
-        self.alpha = nn.Parameter(torch.randn(K + 1, 1))
-
-        # propagation weights W_k (learnable)
-        # each weight maps the Chebyshev expansion result from in_dim to hidden_dim
-        # input: [N, in_dim] -> output: [N, hidden_dim]
-        self.W = nn.ParameterList([
-            nn.Parameter(torch.randn(in_dim, hidden_dim)) for _ in range(K + 1)
-        ])
-
-        self.to(device)
+        self.graph_enhancer = DynamicGraphEnhancer(num_nodes, in_dim, embed_dim)  # dynamic graph enhancer
+        self.alpha = nn.Parameter(torch.randn(K + 1, 1))  # spectral coefficients
+        self.W = nn.ParameterList([nn.Parameter(torch.randn(in_dim, hidden_dim)) for _ in range(K + 1)])  # propagation weights
+        self.to(device)  # move to the target device
 
     def chebyshev_polynomials(self, L_tilde, X):
-        """
-        Recursively compute the Chebyshev polynomial expansion [T_0(L_tilde)X, ..., T_K(L_tilde)X].
-
-        Args:
-            L_tilde: normalized Laplacian matrix [N, N]
-            X: input features [N, in_dim]
-
-        Returns:
-            T_k_list: list of Chebyshev polynomial terms [K+1, N, in_dim]
-        """
-        # T_0(X) = X [N, in_dim]
-        T_k_list = [X]
+        """Compute the Chebyshev polynomial expansion"""
+        T_k_list = [X]  # T_0(X) = X
         if self.K >= 1:
-            # T_1(X) = L_tilde * X [N, in_dim]
-            T_k_list.append(torch.matmul(L_tilde, X))
+            T_k_list.append(torch.matmul(L_tilde, X))  # T_1(X) = L_tilde * X
         for k in range(2, self.K + 1):
-            # T_k(X) = 2*L_tilde*T_{k-1}(X) - T_{k-2}(X) [N, in_dim]
-            T_k_list.append(2 * torch.matmul(L_tilde, T_k_list[-1]) - T_k_list[-2])
-        # list of Chebyshev polynomial terms [K+1, N, in_dim]
-        return T_k_list
+            T_k_list.append(2 * torch.matmul(L_tilde, T_k_list[-1]) - T_k_list[-2])  # recurrence
+        return T_k_list  # return the list of polynomial terms
 
     def forward(self, X):
-        """
-        Args:
-            X: input features [B, N, C, T] or [B, N, T] (single-feature case)
-               B: batch size, N: number of nodes, C: feature channels, T: time-series length
-
-        Returns:
-            enhanced features [B, N, hidden_dim*(K+1)]
-        """
-        batch_size = X.size(0)
-        num_nodes = X.size(1)
+        """Input features [B,N,C,T]; returns enhanced features [B,N,hidden_dim*(K+1)]"""
+        batch_size = X.size(0)  # batch size
+        num_nodes = X.size(1)  # number of nodes
+
+        x = X.permute(0, 2, 1, 3)  # [B,C,N,T]
+        x_proj = self.input_projection(x).squeeze(-1)  # [B,in_dim,N]
+        x_proj = x_proj.permute(0, 2, 1)  # [B,N,in_dim]
 
-        # handle inputs with different numbers of dimensions
-        if len(X.shape) == 4:  # [B, N, C, T]
-            # input: [B, N, C, T] -> output: [B, C, N, T]
-            # convert the input to [B, C, N, T] for 2D convolution
-            x = X.permute(0, 2, 1, 3)  # [B, C, N, T]
-        else:  # [B, N, T]
-            # input: [B, N, T] -> output: [B, 1, N, T]
-            # add a channel dimension
-            x = X.unsqueeze(1)  # [B, 1, N, T]
-
-        # process the temporal dimension with the convolutional projection
-        # input: [B, C, N, T] -> output: [B, in_dim, N, 1]
-        x_proj = self.input_projection(x)
-        # input: [B, in_dim, N, 1] -> output: [B, in_dim, N]
-        x_proj = x_proj.squeeze(-1)  # [B, in_dim, N]
-        # input: [B, in_dim, N] -> output: [B, N, in_dim]
-        x_proj = x_proj.permute(0, 2, 1)  # [B, N, in_dim]
-
-        enhanced_features = []
-
-        # dynamically generate the Laplacian matrices
-        # input: [B, N, in_dim] -> output: [B, N, N]
-        laplacians = self.graph_enhancer(x_proj)  # [B, N, N]
+        enhanced_features = []  # enhanced features
+        laplacians = self.graph_enhancer(x_proj)  # generate dynamic Laplacian matrices
 
         for b in range(batch_size):
-            # Laplacian matrix of the current batch element [N, N]
-            L = laplacians[b]  # [N, N]
+            L = laplacians[b]  # Laplacian of the current batch element
 
             # eigenvalue rescaling
             try:
-                # largest eigenvalue [1]
-                lambda_max = torch.linalg.eigvalsh(L).max().real  # [1]
-                # avoid division by zero
-                if lambda_max < 1e-6:
-                    lambda_max = 1.0
-                # rescale the Laplacian to the [-1, 1] interval [N, N]
-                L_tilde = (2.0 / lambda_max) * L - torch.eye(L.size(0), device=L.device)  # [N, N]
+                lambda_max = torch.linalg.eigvalsh(L).max().real  # largest eigenvalue
+                lambda_max = 1.0 if lambda_max < 1e-6 else lambda_max  # avoid division by zero
+                L_tilde = (2.0 / lambda_max) * L - torch.eye(L.size(0), device=L.device)  # rescaled Laplacian
            except:
-                # fall back to the identity matrix if the eigenvalue computation fails [N, N]
-                L_tilde = torch.eye(num_nodes, device=X.device)  # [N, N]
+                L_tilde = torch.eye(num_nodes, device=X.device)  # fallback on failure
 
-            # compute the Chebyshev polynomial expansion
-            # input: L_tilde [N, N], x_proj [N, in_dim] -> output: [K+1, N, in_dim]
-            T_k_list = self.chebyshev_polynomials(L_tilde, x_proj[b])  # [K+1, N, in_dim]
-            H_list = []
-
-            # apply the propagation weights
-            for k in range(self.K + 1):
-                # matrix product: [N, in_dim] x [in_dim, hidden_dim] -> [N, hidden_dim]
-                H_k = torch.matmul(T_k_list[k], self.W[k])  # [N, hidden_dim]
-                H_list.append(H_k)
-
-            # concatenate the features of all scales
-            # input: [K+1, N, hidden_dim] -> output: [N, hidden_dim*(K+1)]
-            X_enhanced = torch.cat(H_list, dim=-1)  # [N, hidden_dim*(K+1)]
+            # compute the expansion and apply the weights
+            T_k_list = self.chebyshev_polynomials(L_tilde, x_proj[b])  # Chebyshev polynomials
+            H_list = [torch.matmul(T_k_list[k], self.W[k]) for k in range(self.K + 1)]  # apply the weights
+            X_enhanced = torch.cat(H_list, dim=-1)  # concatenate the features
             enhanced_features.append(X_enhanced)
 
-        # stack the enhanced features of all batch elements
-        # input: [B, N, hidden_dim*(K+1)] -> output: [B, N, hidden_dim*(K+1)]
-        return torch.stack(enhanced_features, dim=0)  # [B, N, hidden_dim*(K+1)]
+        return torch.stack(enhanced_features, dim=0)  # stack and return [B,N,hidden_dim*(K+1)]: per-node Chebyshev features for each order k
 
 class AEPSA(nn.Module):
-    """
-    Adaptive feature-projection spatio-temporal self-attention model (AEPSA).
-    Combines dynamic graph enhancement with a pretrained language model for spatio-temporal forecasting.
-    """
+    """Adaptive feature-projection spatio-temporal self-attention model"""
     def __init__(self, configs):
-        # configs: dict containing the full model configuration
-        # main configuration parameters:
-        #   device: compute device
-        #   pred_len: prediction horizon
-        #   seq_len: input sequence length
-        #   patch_len: patch length (the corresponding component has been removed)
-        #   input_dim: input feature dimension
-        #   stride: stride (the corresponding component has been removed)
-        #   dropout: dropout probability
-        #   gpt_layers: number of GPT-2 layers to use
-        #   d_ff: feed-forward hidden dimension
-        #   gpt_path: path to the pretrained GPT-2 model
-        #   num_nodes: number of nodes
-        #   word_num: GumbelSoftmax vocabulary size
-        #   d_model: model dimension
-        #   n_heads: number of attention heads
-        #   chebyshev_order: Chebyshev polynomial order
-        #   graph_hidden_dim: graph encoder hidden dimension
-        #   graph_embed_dim: graph encoder embedding dimension
         super(AEPSA, self).__init__()
-        self.device = configs['device']
-        self.pred_len = configs['pred_len']
-        self.seq_len = configs['seq_len']
-        self.patch_len = configs['patch_len']
-        self.input_dim = configs['input_dim']
-        self.stride = configs['stride']
-        self.dropout = configs['dropout']
-        self.gpt_layers = configs['gpt_layers']
-        self.d_ff = configs['d_ff']
-        self.gpt_path = configs['gpt_path']
-        self.num_nodes = configs.get('num_nodes', 325)  # add the number-of-nodes setting
+        self.device = configs['device']  # compute device
+        self.pred_len = configs['pred_len']  # prediction horizon
+        self.seq_len = configs['seq_len']  # input sequence length
+        self.patch_len = configs['patch_len']  # patch length
+        self.input_dim = configs['input_dim']  # input feature dimension
+        self.stride = configs['stride']  # stride
+        self.dropout = configs['dropout']  # dropout probability
+        self.gpt_layers = configs['gpt_layers']  # number of GPT-2 layers to use
+        self.d_ff = configs['d_ff']  # feed-forward hidden dimension
+        self.gpt_path = configs['gpt_path']  # path to the pretrained GPT-2 model
+        self.num_nodes = configs.get('num_nodes', 325)  # number of nodes
 
-        # GumbelSoftmax layer for word selection
-        # input: [vocab_size] -> output: [vocab_size] (approximate one-hot distribution)
-        self.word_choice = GumbelSoftmax(configs['word_num'])
+        self.word_choice = GumbelSoftmax(configs['word_num'])  # word selection layer
 
-        self.d_model = configs['d_model']
-        self.n_heads = configs['n_heads']
-        self.d_keys = None
+        self.d_model = configs['d_model']  # model dimension
+        self.n_heads = configs['n_heads']  # number of attention heads
+        self.d_keys = None  # key dimension
         self.d_llm = 768  # GPT-2 hidden size
-        self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2)
-        self.head_nf = self.d_ff * self.patch_nums
+        self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2)  # number of patches
+        self.head_nf = self.d_ff * self.patch_nums  # head feature dimension
 
-        # the unused patch_embedding layer has been removed
-
-        # GPT-2 initialization
-        # load the pretrained GPT-2 model with attention weights and hidden states enabled
-        self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
+        # initialize the GPT-2 model
+        self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)  # GPT-2 model
         self.gpts.h = self.gpts.h[:self.gpt_layers]  # keep only the requested number of layers
-        self.gpts.apply(self.reset_parameters)
+        self.gpts.apply(self.reset_parameters)  # reset parameters
 
-        # GPT-2 word embedding weights
-        # shape: [vocab_size, d_llm]
-        self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)
-        self.vocab_size = self.word_embeddings.shape[0]
-        # mapping layer that maps the vocabulary dimension down to 1
-        # input: [vocab_size] -> output: [1]
-        self.mapping_layer = nn.Linear(self.vocab_size, 1)
-        # reprogramming layer for feature mapping and attention computation
-        # input: [B, N, d_model], [d_llm], [d_llm] -> output: [B, N, d_model]
-        self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm)
+        self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)  # word embedding weights
+        self.vocab_size = self.word_embeddings.shape[0]  # vocabulary size
+        self.mapping_layer = nn.Linear(self.vocab_size, 1)  # mapping layer
+        self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm)  # reprogramming layer
 
-        # dynamic graph-enhanced encoder
-        # input: [B, N, C, T] -> output: [B, N, hidden_dim*(K+1)]
+        # initialize the graph-enhanced encoder
         self.graph_encoder = GraphEnhancedEncoder(
             K=configs.get('chebyshev_order', 3),  # Chebyshev polynomial order
             in_dim=self.d_model,  # input feature dimension
@@ -322,11 +161,9 @@ class AEPSA(nn.Module):
             num_features=self.input_dim  # number of feature channels
         )
 
-        # graph feature projection layer that maps the graph-enhanced features to d_model
-        # input: [B, N, hidden_dim*(K+1)] -> output: [B, N, d_model]
-        self.graph_projection = nn.Linear(
-            configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1),
-            self.d_model
+        self.graph_projection = nn.Linear(  # graph feature projection: maps the concatenated order-k Chebyshev features to the hidden dimension
+            configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1),  # input dimension
+            self.d_model  # output dimension
         )
 
         self.out_mlp = nn.Sequential(
@@ -335,11 +172,9 @@ class AEPSA(nn.Module):
             nn.Linear(128, self.pred_len)
         )
 
-        for i, (name, param) in enumerate(self.gpts.named_parameters()):
-            if 'wpe' in name:
-                param.requires_grad = True
-            else:
-                param.requires_grad = False
+        # set which parameters stay trainable (wpe = word position embeddings)
+        for name, param in self.gpts.named_parameters():
+            param.requires_grad = 'wpe' in name
 
     def reset_parameters(self, module):
         if hasattr(module, 'weight') and module.weight is not None:
@@ -348,60 +183,26 @@ class AEPSA(nn.Module):
             torch.nn.init.zeros_(module.bias)
 
     def forward(self, x):
-        """
-        Forward pass.
-
-        Input:
-            x: input data [B, T, N, C], where B is the batch size, T the number of time steps, N the number of nodes and C the number of feature channels
+        # data preprocessing
+        x = x[..., :1]  # [B,T,N,1]
+        x_enc = rearrange(x, 'b t n c -> b n c t')  # [B,N,1,T]
 
-        Returns:
-            outputs: predictions [B, pred_len, N, 1]
-        """
-        # keep only the first feature channel
-        # input: [B, T, N, C] -> output: [B, T, N, 1]
-        x = x[..., :1]  # [B, T, N, 1]
+        # graph encoding
+        graph_enhanced = self.graph_encoder(x_enc)  # [B,N,1,T] -> [B, N, hidden_dim*(K+1)]
+        enc_out = self.graph_projection(graph_enhanced)  # [B,N,d_model]
 
-        # rearrange the input to fit the graph encoder
-        # input: [B, T, N, 1] -> output: [B, N, 1, T]
-        x_enc = rearrange(x, 'b t n c -> b n c t')  # [B, N, 1, T]
-
-        # apply the graph-enhanced encoder to obtain enhanced features
-        # input: [B, N, 1, T] -> output: [B, N, hidden_dim*(K+1)]
-        graph_enhanced = self.graph_encoder(x_enc)  # [B, N, hidden_dim*(K+1)]
-
-        # project the graph-enhanced features to the model dimension
-        # input: [B, N, hidden_dim*(K+1)] -> output: [B, N, d_model]
-        enc_out = self.graph_projection(graph_enhanced)  # [B, N, d_model]
-
-        # process the word embedding weights for the attention mechanism
-        # input: [vocab_size, d_llm] -> output: [d_llm, vocab_size] -> [d_llm, vocab_size]
-        self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)  # [vocab_size, d_llm]
-
-        # select words via GumbelSoftmax
-        # input: [d_llm, 1] -> output: [d_llm, 1]
-        masks = self.word_choice(self.mapping_layer.weight.data.permute(1,0))  # [d_llm, 1]
-
-        # gather the selected source embeddings
-        # input: [vocab_size, d_llm] together with masks -> output: [selected_words, d_llm]
-        source_embeddings = self.word_embeddings[masks==1]  # [selected_words, d_llm]
+        # word-embedding processing
+        self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
+        masks = self.word_choice(self.mapping_layer.weight.data.permute(1,0))  # [d_llm,1]
+        source_embeddings = self.word_embeddings[masks==1]  # [selected_words,d_llm]
 
-        # apply the reprogramming layer to the features and the source embeddings
-        # input: [B, N, d_model], [selected_words, d_llm], [selected_words, d_llm] -> output: [B, N, d_model]
-        enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)  # [B, N, d_model]
+        # reprogramming and prediction
+        enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
+        enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state  # [B,N,d_llm]
+        dec_out = self.out_mlp(enc_out)  # [B,N,pred_len]
 
-        # process the enhanced features with the GPT-2 model
-        # input: [B, N, d_model] -> output: [B, N, d_llm]
-        enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state  # [B, N, d_llm]
-
-        # predict the future time steps with the MLP
-        # input: [B, N, d_llm] -> output: [B, N, pred_len]
-        dec_out = self.out_mlp(enc_out)  # [B, N, pred_len]
-
-        # add a channel dimension
-        # input: [B, N, pred_len] -> output: [B, N, pred_len, 1]
-        outputs = dec_out.unsqueeze(dim=-1)  # [B, N, pred_len, 1]
-
-        # reorder the dimensions to [B, pred_len, N, 1]
-        # input: [B, N, pred_len, 1] -> output: [B, pred_len, N, 1]
-        outputs = outputs.permute(0, 2, 1, 3)  # [B, pred_len, N, 1]
+        # dimension adjustment
+        outputs = dec_out.unsqueeze(dim=-1)  # [B,N,pred_len,1]
+        outputs = outputs.permute(0, 2, 1, 3)  # [B,pred_len,N,1]
 
-        return outputs  # [B, pred_len, N, 1]
\ No newline at end of file
+        return outputs
\ No newline at end of file
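
Reviewer note (not part of the patch): as a quick reference for the new GraphEnhancedEncoder, the standalone sketch below walks through the per-sample pipeline the added code implements (feature-modulated node embeddings, then dynamic adjacency, normalized support matrix, and Chebyshev expansion). All sizes and the stand-in for feature_transform are made up for illustration and are not the model's actual configuration.

    import torch
    import torch.nn.functional as F

    num_nodes, in_dim, embed_dim, K = 5, 8, 10, 3
    X = torch.randn(num_nodes, in_dim)                         # node features for one sample [N, in_dim]
    node_emb = torch.randn(num_nodes, embed_dim)               # stands in for the learnable node_embeddings

    filt = torch.sigmoid(X @ torch.randn(in_dim, embed_dim))   # stand-in for feature_transform(X)
    nodevec = torch.tanh(node_emb * filt)                      # feature-modulated embeddings [N, embed_dim]
    adj = F.relu(nodevec @ nodevec.T)                          # dynamic adjacency [N, N]

    d_inv = torch.diag_embed(adj.sum(-1).pow(-0.5))            # D^(-1/2)
    d_inv[torch.isinf(d_inv)] = 0.0                            # guard against isolated nodes
    L = d_inv @ adj @ d_inv                                    # what get_laplacian returns with normalize=True [N, N]

    lam = torch.linalg.eigvalsh(L).max().clamp(min=1e-6)       # largest eigenvalue, kept away from zero
    L_tilde = (2.0 / lam) * L - torch.eye(num_nodes)           # rescaled to roughly [-1, 1]

    T = [X, L_tilde @ X]                                       # Chebyshev terms T_0, T_1
    for _ in range(2, K + 1):
        T.append(2 * L_tilde @ T[-1] - T[-2])                  # T_k = 2*L_tilde*T_{k-1} - T_{k-2}
    out = torch.cat(T, dim=-1)                                 # [N, in_dim*(K+1)], before the per-order W_k projections
    print(out.shape)                                           # torch.Size([5, 32])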