Fix astra bug
This commit is contained in:
parent 85257bc61c
commit b38e4a5da2
@@ -32,6 +32,7 @@ model:
   seq_len: 24
   stride: 7
   word_num: 1000
+  output_dim: 6

 train:
   batch_size: 16
@@ -32,6 +32,7 @@ model:
   seq_len: 24
   stride: 7
   word_num: 1000
+  output_dim: 6

 train:
   batch_size: 16
@@ -32,6 +32,7 @@ model:
   seq_len: 24
   stride: 7
   word_num: 1000
+  output_dim: 6

 train:
   batch_size: 16
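The three hunks above add the same `output_dim: 6` key to each model config. For reference, a minimal sketch of how such a key reaches the model (the file name and loading code are assumptions; only the YAML keys and the `configs.get('output_dim', 1)` read added later in this commit come from the diff):

import yaml  # assumes PyYAML is installed

# hypothetical path; the actual config file names are not shown in this diff
with open('config/astra.yaml') as f:
    cfg = yaml.safe_load(f)

model_cfg = cfg['model']
print(model_cfg.get('output_dim', 1))  # -> 6 once the hunk above is applied; 1 for old configs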
@@ -10,7 +10,7 @@ data:
   column_wise: false
   days_per_week: 7
   horizon: 24
-  input_dim: 137
+  input_dim: 1
   lag: 24
   normalizer: std
   num_nodes: 137
@@ -7,22 +7,15 @@ from model.ASTRA.reprogramming import PatchEmbedding, ReprogrammingLayer
 import torch.nn.functional as F

 class DynamicGraphEnhancer(nn.Module):
-    """
-    Dynamic graph enhancer: automatically generates the graph structure from node embeddings.
-    Follows the DDGCRN design, computing the adjacency matrix dynamically from node embeddings and feature information.
-    """
+    """Dynamic graph enhancement encoder"""
     def __init__(self, num_nodes, in_dim, embed_dim=10):
         super().__init__()
-        self.num_nodes = num_nodes
-        self.embed_dim = embed_dim
+        self.num_nodes = num_nodes  # number of nodes
+        self.embed_dim = embed_dim  # node embedding dimension

-        # node embedding parameters
-        self.node_embeddings = nn.Parameter(
-            torch.randn(num_nodes, embed_dim), requires_grad=True
-        )
+        self.node_embeddings = nn.Parameter(torch.randn(num_nodes, embed_dim), requires_grad=True)  # node embedding parameters

-        # feature transform layer used to generate dynamically adjusted embeddings
-        self.feature_transform = nn.Sequential(
+        self.feature_transform = nn.Sequential(  # feature transform network
             nn.Linear(in_dim, 16),
             nn.Sigmoid(),
             nn.Linear(16, 2),
@@ -30,48 +23,29 @@ class DynamicGraphEnhancer(nn.Module):
             nn.Linear(2, embed_dim)
         )

-        # register the identity matrix as a fixed support matrix
-        self.register_buffer("eye", torch.eye(num_nodes))
+        self.register_buffer("eye", torch.eye(num_nodes))  # register the identity matrix

     def get_laplacian(self, graph, I, normalize=True):
-        """
-        Compute the normalized Laplacian matrix.
-        """
-        # inverse square root of the degree matrix
-        D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))
+        D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))  # inverse square root of the degree matrix
         D_inv[torch.isinf(D_inv)] = 0.0  # handle division by zero

         if normalize:
-            return torch.matmul(torch.matmul(D_inv, graph), D_inv)
+            return torch.matmul(torch.matmul(D_inv, graph), D_inv)  # normalized Laplacian
         else:
-            return torch.matmul(torch.matmul(D_inv, graph + I), D_inv)
+            return torch.matmul(torch.matmul(D_inv, graph + I), D_inv)  # normalized Laplacian with self-loops

     def forward(self, X):
-        """
-        X: input features [B, N, D]
-        Returns: dynamically generated normalized Laplacian matrices [B, N, N]
-        """
-        batch_size = X.size(0)
-        laplacians = []
-
-        # get the identity matrix
-        I = self.eye.to(X.device)
+        """Generate dynamic Laplacian matrices"""
+        batch_size = X.size(0)  # batch size
+        laplacians = []  # Laplacians for each batch element
+        I = self.eye.to(X.device)  # move the identity matrix to the target device

         for b in range(batch_size):
-            # use the feature transform to generate dynamic embedding adjustment factors
-            filt = self.feature_transform(X[b])  # [N, embed_dim]
-
-            # compute the node embedding vectors
-            nodevec = torch.tanh(self.node_embeddings * filt)
-
-            # compute the adjacency matrix from the inner product of node embeddings
-            adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1)))
-
-            # compute the normalized Laplacian
-            laplacian = self.get_laplacian(adj, I)
+            filt = self.feature_transform(X[b])  # feature transform
+            nodevec = torch.tanh(self.node_embeddings * filt)  # compute node embeddings
+            adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1)))  # compute adjacency matrix
+            laplacian = self.get_laplacian(adj, I)  # compute Laplacian
             laplacians.append(laplacian)

-        return torch.stack(laplacians, dim=0)
+        return torch.stack(laplacians, dim=0)  # stack and return

 class GraphEnhancedEncoder(nn.Module):
     """
@@ -190,8 +164,8 @@ class ASTRA(nn.Module):
         # add the dynamic graph enhancement encoder
         self.graph_encoder = GraphEnhancedEncoder(
             K=configs.get('chebyshev_order', 3),
-            in_dim=self.d_model,
-            hidden_dim=configs.get('graph_hidden_dim', 32),
+            in_dim=self.d_model * self.input_dim,
+            hidden_dim=self.d_model,
             num_nodes=self.num_nodes,
             embed_dim=configs.get('graph_embed_dim', 10),
             device=self.device
@@ -199,14 +173,14 @@ class ASTRA(nn.Module):

         # feature fusion layer
         self.feature_fusion = nn.Linear(
-            self.d_model + configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1),
+            self.d_model * self.input_dim + self.d_model * (configs.get('chebyshev_order', 3) + 1),
             self.d_model
         )

         self.out_mlp = nn.Sequential(
             nn.Linear(self.d_llm, 128),
             nn.ReLU(),
-            nn.Linear(128, self.pred_len)
+            nn.Linear(128, self.pred_len * self.output_dim)
         )

         for i, (name, param) in enumerate(self.gpts.named_parameters()):
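The two changed in_features keep the fusion layer consistent with the encoder rewiring above: the patch embedding now emits d_model * input_dim per node, and the graph encoder (hidden_dim = d_model, Chebyshev order K) emits d_model * (K + 1), so their concatenation has width d_model * input_dim + d_model * (K + 1). A quick arithmetic check with illustrative values (d_model here is an assumption, not a value from the diff):

# illustrative values; input_dim matches the config change above, K the default chebyshev_order
d_model, input_dim, K = 64, 1, 3

enc_out_dim = d_model * input_dim  # width of the patch-embedding output per node
graph_out_dim = d_model * (K + 1)  # width of the graph-encoder output per node

fusion_in = enc_out_dim + graph_out_dim
print(fusion_in)  # 320, the in_features of the new feature_fusion Linear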
@@ -229,9 +203,9 @@ class ASTRA(nn.Module):
         x = x[..., :self.input_dim]
         x_enc = rearrange(x, 'b t n c -> b n c t')
         # original patching
-        enc_out, n_vars = self.patch_embedding(x_enc)  # (B, N, C)
+        enc_out, n_vars = self.patch_embedding(x_enc)  # (B, N, d_model * input_dim)
         # apply the graph enhancement encoder (graph structure generated automatically)
-        graph_enhanced = self.graph_encoder(enc_out)
+        graph_enhanced = self.graph_encoder(enc_out)  # (B, N, K * hidden_dim)
         # feature fusion - both tensors are now 3-D [B, N, d_model]
         enc_out = torch.cat([enc_out, graph_enhanced], dim=-1)
         enc_out = self.feature_fusion(enc_out)
@@ -243,9 +217,10 @@ class ASTRA(nn.Module):
         enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
         enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state

-        dec_out = self.out_mlp(enc_out)
-        outputs = dec_out.unsqueeze(dim=-1)
-        outputs = outputs.repeat(1, 1, 1, n_vars)
-        outputs = outputs.permute(0,2,1,3)
+        dec_out = self.out_mlp(enc_out)  # [B, N, T*C]
+
+        B, N, _ = dec_out.shape
+        outputs = dec_out.view(B, N, self.pred_len, self.output_dim)
+        outputs = outputs.permute(0, 2, 1, 3)  # B, T, N, C

         return outputs
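This hunk is the heart of the fix: the old head produced a single pred_len-length sequence per node and copied it across n_vars channels with repeat, so all output channels were identical; the new head emits pred_len * output_dim values per node and view splits them into distinct per-channel predictions. A small self-contained demonstration (sizes are illustrative):

import torch

B, N, pred_len, output_dim = 2, 4, 3, 6

# old path: one prediction repeated across channels
old = torch.randn(B, N, pred_len)
old = old.unsqueeze(dim=-1).repeat(1, 1, 1, output_dim).permute(0, 2, 1, 3)
print(torch.equal(old[..., 0], old[..., 1]))  # True: channels are exact copies

# new path: pred_len * output_dim distinct values per node
dec_out = torch.randn(B, N, pred_len * output_dim)  # what the widened out_mlp emits
new = dec_out.view(B, N, pred_len, output_dim).permute(0, 2, 1, 3)  # [B, T, N, C]
print(new.shape)                              # torch.Size([2, 3, 4, 6])
print(torch.equal(new[..., 0], new[..., 1]))  # False: channels differ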
@@ -128,6 +128,7 @@ class ASTRA(nn.Module):
         self.d_ff = configs['d_ff']  # feed-forward hidden dimension
         self.gpt_path = configs['gpt_path']  # path to the pretrained GPT-2 model
         self.num_nodes = configs.get('num_nodes', 325)  # number of nodes
+        self.output_dim = configs.get('output_dim', 1)

         self.word_choice = GumbelSoftmax(configs['word_num'])  # word selection layer
@@ -169,7 +170,7 @@ class ASTRA(nn.Module):
         self.out_mlp = nn.Sequential(
             nn.Linear(self.d_llm, 128),
             nn.ReLU(),
-            nn.Linear(128, self.pred_len)
+            nn.Linear(128, self.pred_len * self.output_dim)
         )

         # set parameter trainability; wps = word position embeddings
@@ -202,9 +203,8 @@ class ASTRA(nn.Module):
         dec_out = self.out_mlp(enc_out)  # [B, N, pred_len]

-        # adjust dimensions
-        dec_out = self.out_mlp(enc_out)
-        outputs = dec_out.unsqueeze(dim=-1)
-        outputs = outputs.repeat(1, 1, 1, self.input_dim)
-        outputs = outputs.permute(0,2,1,3)
+        B, N, _ = dec_out.shape
+        outputs = dec_out.view(B, N, self.pred_len, self.output_dim)
+        outputs = outputs.permute(0, 2, 1, 3)  # B, T, N, C

         return outputs
@@ -128,6 +128,7 @@ class ASTRA(nn.Module):
         self.d_ff = configs['d_ff']  # feed-forward hidden dimension
         self.gpt_path = configs['gpt_path']  # path to the pretrained GPT-2 model
         self.num_nodes = configs.get('num_nodes', 325)  # number of nodes
+        self.output_dim = configs.get('output_dim', 1)

         self.word_choice = GumbelSoftmax(configs['word_num'])  # word selection layer
@@ -169,7 +170,7 @@ class ASTRA(nn.Module):
         self.out_mlp = nn.Sequential(
             nn.Linear(self.d_llm, 128),
             nn.ReLU(),
-            nn.Linear(128, self.pred_len)
+            nn.Linear(128, self.pred_len * self.output_dim)
         )

         # set parameter trainability; wps = word position embeddings
@@ -203,9 +204,8 @@ class ASTRA(nn.Module):
         dec_out = self.out_mlp(X_enc)  # [B, N, pred_len]

-        # adjust dimensions
-        dec_out = self.out_mlp(enc_out)
-        outputs = dec_out.unsqueeze(dim=-1)
-        outputs = outputs.repeat(1, 1, 1, self.input_dim)
-        outputs = outputs.permute(0,2,1,3)
+        B, N, _ = dec_out.shape
+        outputs = dec_out.view(B, N, self.pred_len, self.output_dim)
+        outputs = outputs.permute(0, 2, 1, 3)  # B, T, N, C

         return outputs