diff --git a/config/ASTRA/AirQuality.yaml b/config/ASTRA/AirQuality.yaml index 455fc4b..7d4868e 100644 --- a/config/ASTRA/AirQuality.yaml +++ b/config/ASTRA/AirQuality.yaml @@ -32,6 +32,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + output_dim: 6 train: batch_size: 16 diff --git a/config/ASTRA_v2/AirQuality.yaml b/config/ASTRA_v2/AirQuality.yaml index 10796d2..ed22962 100644 --- a/config/ASTRA_v2/AirQuality.yaml +++ b/config/ASTRA_v2/AirQuality.yaml @@ -32,6 +32,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + output_dim: 6 train: batch_size: 16 diff --git a/config/ASTRA_v3/AirQuality.yaml b/config/ASTRA_v3/AirQuality.yaml index 68e6acc..d4cb947 100644 --- a/config/ASTRA_v3/AirQuality.yaml +++ b/config/ASTRA_v3/AirQuality.yaml @@ -32,6 +32,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + output_dim: 6 train: batch_size: 16 diff --git a/config/STAEFormer/SolarEnergy.yaml b/config/STAEFormer/SolarEnergy.yaml index c1151ca..a3fed30 100644 --- a/config/STAEFormer/SolarEnergy.yaml +++ b/config/STAEFormer/SolarEnergy.yaml @@ -10,7 +10,7 @@ data: column_wise: false days_per_week: 7 horizon: 24 - input_dim: 137 + input_dim: 1 lag: 24 normalizer: std num_nodes: 137 diff --git a/model/ASTRA/astra.py b/model/ASTRA/astra.py index 0ed2333..71d4ee9 100644 --- a/model/ASTRA/astra.py +++ b/model/ASTRA/astra.py @@ -7,22 +7,15 @@ from model.ASTRA.reprogramming import PatchEmbedding, ReprogrammingLayer import torch.nn.functional as F class DynamicGraphEnhancer(nn.Module): - """ - 动态图增强器,基于节点嵌入自动生成图结构 - 参考DDGCRN的设计,使用节点嵌入和特征信息动态计算邻接矩阵 - """ + """动态图增强编码器""" def __init__(self, num_nodes, in_dim, embed_dim=10): super().__init__() - self.num_nodes = num_nodes - self.embed_dim = embed_dim + self.num_nodes = num_nodes # 节点个数 + self.embed_dim = embed_dim # 节点嵌入维度 - # 节点嵌入参数 - self.node_embeddings = nn.Parameter( - torch.randn(num_nodes, embed_dim), requires_grad=True - ) + self.node_embeddings = nn.Parameter(torch.randn(num_nodes, embed_dim), requires_grad=True) # 节点嵌入参数 - # 特征转换层,用于生成动态调整的嵌入 - self.feature_transform = nn.Sequential( + self.feature_transform = nn.Sequential( # 特征转换网络 nn.Linear(in_dim, 16), nn.Sigmoid(), nn.Linear(16, 2), @@ -30,48 +23,29 @@ class DynamicGraphEnhancer(nn.Module): nn.Linear(2, embed_dim) ) - # 注册单位矩阵作为固定的支持矩阵 - self.register_buffer("eye", torch.eye(num_nodes)) + self.register_buffer("eye", torch.eye(num_nodes)) # 注册单位矩阵 def get_laplacian(self, graph, I, normalize=True): - """ - 计算归一化拉普拉斯矩阵 - """ - # 计算度矩阵的逆平方根 - D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5)) + D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5)) # 度矩阵的逆平方根 D_inv[torch.isinf(D_inv)] = 0.0 # 处理零除问题 - if normalize: - return torch.matmul(torch.matmul(D_inv, graph), D_inv) + return torch.matmul(torch.matmul(D_inv, graph), D_inv) # 归一化拉普拉斯矩阵 else: - return torch.matmul(torch.matmul(D_inv, graph + I), D_inv) + return torch.matmul(torch.matmul(D_inv, graph + I), D_inv) # 带自环的归一化拉普拉斯矩阵 def forward(self, X): - """ - X: 输入特征 [B, N, D] - 返回: 动态生成的归一化拉普拉斯矩阵 [B, N, N] - """ - batch_size = X.size(0) - laplacians = [] - - # 获取单位矩阵 - I = self.eye.to(X.device) + """生成动态拉普拉斯矩阵""" + batch_size = X.size(0) # 批次大小 + laplacians = [] # 存储各批次的拉普拉斯矩阵 + I = self.eye.to(X.device) # 移动单位矩阵到目标设备 for b in range(batch_size): - # 使用特征转换层生成动态嵌入调整因子 - filt = self.feature_transform(X[b]) # [N, embed_dim] - - # 计算节点嵌入向量 - nodevec = torch.tanh(self.node_embeddings * filt) - - # 通过节点嵌入的点积计算邻接矩阵 - adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1))) - - # 计算归一化拉普拉斯矩阵 - laplacian = self.get_laplacian(adj, I) + filt = self.feature_transform(X[b]) # 特征转换 + nodevec = torch.tanh(self.node_embeddings * filt) # 计算节点嵌入 + adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1))) # 计算邻接矩阵 + laplacian = self.get_laplacian(adj, I) # 计算拉普拉斯矩阵 laplacians.append(laplacian) - - return torch.stack(laplacians, dim=0) + return torch.stack(laplacians, dim=0) # 堆叠并返回 class GraphEnhancedEncoder(nn.Module): """ @@ -190,8 +164,8 @@ class ASTRA(nn.Module): # 添加动态图增强编码器 self.graph_encoder = GraphEnhancedEncoder( K=configs.get('chebyshev_order', 3), - in_dim=self.d_model, - hidden_dim=configs.get('graph_hidden_dim', 32), + in_dim=self.d_model * self.input_dim, + hidden_dim=self.d_model, num_nodes=self.num_nodes, embed_dim=configs.get('graph_embed_dim', 10), device=self.device @@ -199,14 +173,14 @@ class ASTRA(nn.Module): # 特征融合层 self.feature_fusion = nn.Linear( - self.d_model + configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1), + self.d_model * self.input_dim + self.d_model * (configs.get('chebyshev_order', 3) + 1), self.d_model ) self.out_mlp = nn.Sequential( nn.Linear(self.d_llm, 128), nn.ReLU(), - nn.Linear(128, self.pred_len) + nn.Linear(128, self.pred_len * self.output_dim) ) for i, (name, param) in enumerate(self.gpts.named_parameters()): @@ -229,9 +203,9 @@ class ASTRA(nn.Module): x = x[..., :self.input_dim] x_enc = rearrange(x, 'b t n c -> b n c t') # 原版Patch - enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, C) + enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, d_model * input_dim) # 应用图增强编码器(自动生成图结构) - graph_enhanced = self.graph_encoder(enc_out) + graph_enhanced = self.graph_encoder(enc_out) # (B, N, K * hidden_dim) # 特征融合 - 现在两个张量都是三维的 [B, N, d_model] enc_out = torch.cat([enc_out, graph_enhanced], dim=-1) enc_out = self.feature_fusion(enc_out) @@ -243,9 +217,10 @@ class ASTRA(nn.Module): enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings) enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state - dec_out = self.out_mlp(enc_out) - outputs = dec_out.unsqueeze(dim=-1) - outputs = outputs.repeat(1, 1, 1, n_vars) - outputs = outputs.permute(0,2,1,3) + dec_out = self.out_mlp(enc_out) #[B, N, T*C] + + B, N, _ = dec_out.shape + outputs = dec_out.view(B, N, self.pred_len, self.output_dim) + outputs = outputs.permute(0, 2, 1, 3) # B, T, N, C return outputs diff --git a/model/ASTRA/astrav2.py b/model/ASTRA/astrav2.py index 6a47206..22e25b9 100644 --- a/model/ASTRA/astrav2.py +++ b/model/ASTRA/astrav2.py @@ -128,6 +128,7 @@ class ASTRA(nn.Module): self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度 self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径 self.num_nodes = configs.get('num_nodes', 325) # 节点数量 + self.output_dim = configs.get('output_dim', 1) self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层 @@ -169,7 +170,7 @@ class ASTRA(nn.Module): self.out_mlp = nn.Sequential( nn.Linear(self.d_llm, 128), nn.ReLU(), - nn.Linear(128, self.pred_len) + nn.Linear(128, self.pred_len * self.output_dim) ) # 设置参数可训练性 wps=word position embeddings @@ -202,9 +203,8 @@ class ASTRA(nn.Module): dec_out = self.out_mlp(enc_out) # [B,N,pred_len] # 维度调整 - dec_out = self.out_mlp(enc_out) - outputs = dec_out.unsqueeze(dim=-1) - outputs = outputs.repeat(1, 1, 1, self.input_dim) - outputs = outputs.permute(0,2,1,3) + B, N, _ = dec_out.shape + outputs = dec_out.view(B, N, self.pred_len, self.output_dim) + outputs = outputs.permute(0, 2, 1, 3) # B, T, N, C return outputs \ No newline at end of file diff --git a/model/ASTRA/astrav3.py b/model/ASTRA/astrav3.py index 0e9aebf..59fc11d 100644 --- a/model/ASTRA/astrav3.py +++ b/model/ASTRA/astrav3.py @@ -128,6 +128,7 @@ class ASTRA(nn.Module): self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度 self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径 self.num_nodes = configs.get('num_nodes', 325) # 节点数量 + self.output_dim = configs.get('output_dim', 1) self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层 @@ -169,7 +170,7 @@ class ASTRA(nn.Module): self.out_mlp = nn.Sequential( nn.Linear(self.d_llm, 128), nn.ReLU(), - nn.Linear(128, self.pred_len) + nn.Linear(128, self.pred_len * self.output_dim) ) # 设置参数可训练性 wps=word position embeddings @@ -203,9 +204,8 @@ class ASTRA(nn.Module): dec_out = self.out_mlp(X_enc) # [B,N,pred_len] # 维度调整 - dec_out = self.out_mlp(enc_out) - outputs = dec_out.unsqueeze(dim=-1) - outputs = outputs.repeat(1, 1, 1, self.input_dim) - outputs = outputs.permute(0,2,1,3) + B, N, _ = dec_out.shape + outputs = dec_out.view(B, N, self.pred_len, self.output_dim) + outputs = outputs.permute(0, 2, 1, 3) # B, T, N, C return outputs \ No newline at end of file