修复astra bug
This commit is contained in:
parent
85257bc61c
commit
b38e4a5da2
|
|
@ -32,6 +32,7 @@ model:
|
||||||
seq_len: 24
|
seq_len: 24
|
||||||
stride: 7
|
stride: 7
|
||||||
word_num: 1000
|
word_num: 1000
|
||||||
|
output_dim: 6
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@ model:
|
||||||
seq_len: 24
|
seq_len: 24
|
||||||
stride: 7
|
stride: 7
|
||||||
word_num: 1000
|
word_num: 1000
|
||||||
|
output_dim: 6
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@ model:
|
||||||
seq_len: 24
|
seq_len: 24
|
||||||
stride: 7
|
stride: 7
|
||||||
word_num: 1000
|
word_num: 1000
|
||||||
|
output_dim: 6
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ data:
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
input_dim: 137
|
input_dim: 1
|
||||||
lag: 24
|
lag: 24
|
||||||
normalizer: std
|
normalizer: std
|
||||||
num_nodes: 137
|
num_nodes: 137
|
||||||
|
|
|
||||||
|
|
@ -7,22 +7,15 @@ from model.ASTRA.reprogramming import PatchEmbedding, ReprogrammingLayer
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
class DynamicGraphEnhancer(nn.Module):
|
class DynamicGraphEnhancer(nn.Module):
|
||||||
"""
|
"""动态图增强编码器"""
|
||||||
动态图增强器,基于节点嵌入自动生成图结构
|
|
||||||
参考DDGCRN的设计,使用节点嵌入和特征信息动态计算邻接矩阵
|
|
||||||
"""
|
|
||||||
def __init__(self, num_nodes, in_dim, embed_dim=10):
|
def __init__(self, num_nodes, in_dim, embed_dim=10):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_nodes = num_nodes
|
self.num_nodes = num_nodes # 节点个数
|
||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim # 节点嵌入维度
|
||||||
|
|
||||||
# 节点嵌入参数
|
self.node_embeddings = nn.Parameter(torch.randn(num_nodes, embed_dim), requires_grad=True) # 节点嵌入参数
|
||||||
self.node_embeddings = nn.Parameter(
|
|
||||||
torch.randn(num_nodes, embed_dim), requires_grad=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 特征转换层,用于生成动态调整的嵌入
|
self.feature_transform = nn.Sequential( # 特征转换网络
|
||||||
self.feature_transform = nn.Sequential(
|
|
||||||
nn.Linear(in_dim, 16),
|
nn.Linear(in_dim, 16),
|
||||||
nn.Sigmoid(),
|
nn.Sigmoid(),
|
||||||
nn.Linear(16, 2),
|
nn.Linear(16, 2),
|
||||||
|
|
@ -30,48 +23,29 @@ class DynamicGraphEnhancer(nn.Module):
|
||||||
nn.Linear(2, embed_dim)
|
nn.Linear(2, embed_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 注册单位矩阵作为固定的支持矩阵
|
self.register_buffer("eye", torch.eye(num_nodes)) # 注册单位矩阵
|
||||||
self.register_buffer("eye", torch.eye(num_nodes))
|
|
||||||
|
|
||||||
def get_laplacian(self, graph, I, normalize=True):
|
def get_laplacian(self, graph, I, normalize=True):
|
||||||
"""
|
D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5)) # 度矩阵的逆平方根
|
||||||
计算归一化拉普拉斯矩阵
|
|
||||||
"""
|
|
||||||
# 计算度矩阵的逆平方根
|
|
||||||
D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))
|
|
||||||
D_inv[torch.isinf(D_inv)] = 0.0 # 处理零除问题
|
D_inv[torch.isinf(D_inv)] = 0.0 # 处理零除问题
|
||||||
|
|
||||||
if normalize:
|
if normalize:
|
||||||
return torch.matmul(torch.matmul(D_inv, graph), D_inv)
|
return torch.matmul(torch.matmul(D_inv, graph), D_inv) # 归一化拉普拉斯矩阵
|
||||||
else:
|
else:
|
||||||
return torch.matmul(torch.matmul(D_inv, graph + I), D_inv)
|
return torch.matmul(torch.matmul(D_inv, graph + I), D_inv) # 带自环的归一化拉普拉斯矩阵
|
||||||
|
|
||||||
def forward(self, X):
|
def forward(self, X):
|
||||||
"""
|
"""生成动态拉普拉斯矩阵"""
|
||||||
X: 输入特征 [B, N, D]
|
batch_size = X.size(0) # 批次大小
|
||||||
返回: 动态生成的归一化拉普拉斯矩阵 [B, N, N]
|
laplacians = [] # 存储各批次的拉普拉斯矩阵
|
||||||
"""
|
I = self.eye.to(X.device) # 移动单位矩阵到目标设备
|
||||||
batch_size = X.size(0)
|
|
||||||
laplacians = []
|
|
||||||
|
|
||||||
# 获取单位矩阵
|
|
||||||
I = self.eye.to(X.device)
|
|
||||||
|
|
||||||
for b in range(batch_size):
|
for b in range(batch_size):
|
||||||
# 使用特征转换层生成动态嵌入调整因子
|
filt = self.feature_transform(X[b]) # 特征转换
|
||||||
filt = self.feature_transform(X[b]) # [N, embed_dim]
|
nodevec = torch.tanh(self.node_embeddings * filt) # 计算节点嵌入
|
||||||
|
adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1))) # 计算邻接矩阵
|
||||||
# 计算节点嵌入向量
|
laplacian = self.get_laplacian(adj, I) # 计算拉普拉斯矩阵
|
||||||
nodevec = torch.tanh(self.node_embeddings * filt)
|
|
||||||
|
|
||||||
# 通过节点嵌入的点积计算邻接矩阵
|
|
||||||
adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1)))
|
|
||||||
|
|
||||||
# 计算归一化拉普拉斯矩阵
|
|
||||||
laplacian = self.get_laplacian(adj, I)
|
|
||||||
laplacians.append(laplacian)
|
laplacians.append(laplacian)
|
||||||
|
return torch.stack(laplacians, dim=0) # 堆叠并返回
|
||||||
return torch.stack(laplacians, dim=0)
|
|
||||||
|
|
||||||
class GraphEnhancedEncoder(nn.Module):
|
class GraphEnhancedEncoder(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|
@ -190,8 +164,8 @@ class ASTRA(nn.Module):
|
||||||
# 添加动态图增强编码器
|
# 添加动态图增强编码器
|
||||||
self.graph_encoder = GraphEnhancedEncoder(
|
self.graph_encoder = GraphEnhancedEncoder(
|
||||||
K=configs.get('chebyshev_order', 3),
|
K=configs.get('chebyshev_order', 3),
|
||||||
in_dim=self.d_model,
|
in_dim=self.d_model * self.input_dim,
|
||||||
hidden_dim=configs.get('graph_hidden_dim', 32),
|
hidden_dim=self.d_model,
|
||||||
num_nodes=self.num_nodes,
|
num_nodes=self.num_nodes,
|
||||||
embed_dim=configs.get('graph_embed_dim', 10),
|
embed_dim=configs.get('graph_embed_dim', 10),
|
||||||
device=self.device
|
device=self.device
|
||||||
|
|
@ -199,14 +173,14 @@ class ASTRA(nn.Module):
|
||||||
|
|
||||||
# 特征融合层
|
# 特征融合层
|
||||||
self.feature_fusion = nn.Linear(
|
self.feature_fusion = nn.Linear(
|
||||||
self.d_model + configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1),
|
self.d_model * self.input_dim + self.d_model * (configs.get('chebyshev_order', 3) + 1),
|
||||||
self.d_model
|
self.d_model
|
||||||
)
|
)
|
||||||
|
|
||||||
self.out_mlp = nn.Sequential(
|
self.out_mlp = nn.Sequential(
|
||||||
nn.Linear(self.d_llm, 128),
|
nn.Linear(self.d_llm, 128),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Linear(128, self.pred_len)
|
nn.Linear(128, self.pred_len * self.output_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
for i, (name, param) in enumerate(self.gpts.named_parameters()):
|
for i, (name, param) in enumerate(self.gpts.named_parameters()):
|
||||||
|
|
@ -229,9 +203,9 @@ class ASTRA(nn.Module):
|
||||||
x = x[..., :self.input_dim]
|
x = x[..., :self.input_dim]
|
||||||
x_enc = rearrange(x, 'b t n c -> b n c t')
|
x_enc = rearrange(x, 'b t n c -> b n c t')
|
||||||
# 原版Patch
|
# 原版Patch
|
||||||
enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, C)
|
enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, d_model * input_dim)
|
||||||
# 应用图增强编码器(自动生成图结构)
|
# 应用图增强编码器(自动生成图结构)
|
||||||
graph_enhanced = self.graph_encoder(enc_out)
|
graph_enhanced = self.graph_encoder(enc_out) # (B, N, K * hidden_dim)
|
||||||
# 特征融合 - 现在两个张量都是三维的 [B, N, d_model]
|
# 特征融合 - 现在两个张量都是三维的 [B, N, d_model]
|
||||||
enc_out = torch.cat([enc_out, graph_enhanced], dim=-1)
|
enc_out = torch.cat([enc_out, graph_enhanced], dim=-1)
|
||||||
enc_out = self.feature_fusion(enc_out)
|
enc_out = self.feature_fusion(enc_out)
|
||||||
|
|
@ -243,9 +217,10 @@ class ASTRA(nn.Module):
|
||||||
enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
|
enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
|
||||||
enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state
|
enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state
|
||||||
|
|
||||||
dec_out = self.out_mlp(enc_out)
|
dec_out = self.out_mlp(enc_out) #[B, N, T*C]
|
||||||
outputs = dec_out.unsqueeze(dim=-1)
|
|
||||||
outputs = outputs.repeat(1, 1, 1, n_vars)
|
B, N, _ = dec_out.shape
|
||||||
outputs = outputs.permute(0,2,1,3)
|
outputs = dec_out.view(B, N, self.pred_len, self.output_dim)
|
||||||
|
outputs = outputs.permute(0, 2, 1, 3) # B, T, N, C
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
|
||||||
|
|
@ -128,6 +128,7 @@ class ASTRA(nn.Module):
|
||||||
self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度
|
self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度
|
||||||
self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径
|
self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径
|
||||||
self.num_nodes = configs.get('num_nodes', 325) # 节点数量
|
self.num_nodes = configs.get('num_nodes', 325) # 节点数量
|
||||||
|
self.output_dim = configs.get('output_dim', 1)
|
||||||
|
|
||||||
self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层
|
self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层
|
||||||
|
|
||||||
|
|
@ -169,7 +170,7 @@ class ASTRA(nn.Module):
|
||||||
self.out_mlp = nn.Sequential(
|
self.out_mlp = nn.Sequential(
|
||||||
nn.Linear(self.d_llm, 128),
|
nn.Linear(self.d_llm, 128),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Linear(128, self.pred_len)
|
nn.Linear(128, self.pred_len * self.output_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 设置参数可训练性 wps=word position embeddings
|
# 设置参数可训练性 wps=word position embeddings
|
||||||
|
|
@ -202,9 +203,8 @@ class ASTRA(nn.Module):
|
||||||
dec_out = self.out_mlp(enc_out) # [B,N,pred_len]
|
dec_out = self.out_mlp(enc_out) # [B,N,pred_len]
|
||||||
|
|
||||||
# 维度调整
|
# 维度调整
|
||||||
dec_out = self.out_mlp(enc_out)
|
B, N, _ = dec_out.shape
|
||||||
outputs = dec_out.unsqueeze(dim=-1)
|
outputs = dec_out.view(B, N, self.pred_len, self.output_dim)
|
||||||
outputs = outputs.repeat(1, 1, 1, self.input_dim)
|
outputs = outputs.permute(0, 2, 1, 3) # B, T, N, C
|
||||||
outputs = outputs.permute(0,2,1,3)
|
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
@ -128,6 +128,7 @@ class ASTRA(nn.Module):
|
||||||
self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度
|
self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度
|
||||||
self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径
|
self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径
|
||||||
self.num_nodes = configs.get('num_nodes', 325) # 节点数量
|
self.num_nodes = configs.get('num_nodes', 325) # 节点数量
|
||||||
|
self.output_dim = configs.get('output_dim', 1)
|
||||||
|
|
||||||
self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层
|
self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层
|
||||||
|
|
||||||
|
|
@ -169,7 +170,7 @@ class ASTRA(nn.Module):
|
||||||
self.out_mlp = nn.Sequential(
|
self.out_mlp = nn.Sequential(
|
||||||
nn.Linear(self.d_llm, 128),
|
nn.Linear(self.d_llm, 128),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Linear(128, self.pred_len)
|
nn.Linear(128, self.pred_len * self.output_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 设置参数可训练性 wps=word position embeddings
|
# 设置参数可训练性 wps=word position embeddings
|
||||||
|
|
@ -203,9 +204,8 @@ class ASTRA(nn.Module):
|
||||||
dec_out = self.out_mlp(X_enc) # [B,N,pred_len]
|
dec_out = self.out_mlp(X_enc) # [B,N,pred_len]
|
||||||
|
|
||||||
# 维度调整
|
# 维度调整
|
||||||
dec_out = self.out_mlp(enc_out)
|
B, N, _ = dec_out.shape
|
||||||
outputs = dec_out.unsqueeze(dim=-1)
|
outputs = dec_out.view(B, N, self.pred_len, self.output_dim)
|
||||||
outputs = outputs.repeat(1, 1, 1, self.input_dim)
|
outputs = outputs.permute(0, 2, 1, 3) # B, T, N, C
|
||||||
outputs = outputs.permute(0,2,1,3)
|
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
Loading…
Reference in New Issue