REPST #3

Merged
czzhangheng merged 42 commits from REPST into main 2025-12-20 16:03:22 +08:00
7 changed files with 43 additions and 65 deletions
Showing only changes of commit b38e4a5da2 - Show all commits

View File

@ -32,6 +32,7 @@ model:
seq_len: 24
stride: 7
word_num: 1000
output_dim: 6
train:
batch_size: 16

View File

@ -32,6 +32,7 @@ model:
seq_len: 24
stride: 7
word_num: 1000
output_dim: 6
train:
batch_size: 16

View File

@ -32,6 +32,7 @@ model:
seq_len: 24
stride: 7
word_num: 1000
output_dim: 6
train:
batch_size: 16

View File

@ -10,7 +10,7 @@ data:
column_wise: false
days_per_week: 7
horizon: 24
input_dim: 137
input_dim: 1
lag: 24
normalizer: std
num_nodes: 137

View File

@ -7,22 +7,15 @@ from model.ASTRA.reprogramming import PatchEmbedding, ReprogrammingLayer
import torch.nn.functional as F
class DynamicGraphEnhancer(nn.Module):
"""
动态图增强器基于节点嵌入自动生成图结构
参考DDGCRN的设计使用节点嵌入和特征信息动态计算邻接矩阵
"""
"""动态图增强编码器"""
def __init__(self, num_nodes, in_dim, embed_dim=10):
super().__init__()
self.num_nodes = num_nodes
self.embed_dim = embed_dim
self.num_nodes = num_nodes # 节点个数
self.embed_dim = embed_dim # 节点嵌入维度
# 节点嵌入参数
self.node_embeddings = nn.Parameter(
torch.randn(num_nodes, embed_dim), requires_grad=True
)
self.node_embeddings = nn.Parameter(torch.randn(num_nodes, embed_dim), requires_grad=True) # 节点嵌入参数
# 特征转换层,用于生成动态调整的嵌入
self.feature_transform = nn.Sequential(
self.feature_transform = nn.Sequential( # 特征转换网络
nn.Linear(in_dim, 16),
nn.Sigmoid(),
nn.Linear(16, 2),
@ -30,48 +23,29 @@ class DynamicGraphEnhancer(nn.Module):
nn.Linear(2, embed_dim)
)
# 注册单位矩阵作为固定的支持矩阵
self.register_buffer("eye", torch.eye(num_nodes))
self.register_buffer("eye", torch.eye(num_nodes)) # 注册单位矩阵
def get_laplacian(self, graph, I, normalize=True):
"""
计算归一化拉普拉斯矩阵
"""
# 计算度矩阵的逆平方根
D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))
D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5)) # 度矩阵的逆平方根
D_inv[torch.isinf(D_inv)] = 0.0 # 处理零除问题
if normalize:
return torch.matmul(torch.matmul(D_inv, graph), D_inv)
return torch.matmul(torch.matmul(D_inv, graph), D_inv) # 归一化拉普拉斯矩阵
else:
return torch.matmul(torch.matmul(D_inv, graph + I), D_inv)
return torch.matmul(torch.matmul(D_inv, graph + I), D_inv) # 带自环的归一化拉普拉斯矩阵
def forward(self, X):
"""
X: 输入特征 [B, N, D]
返回: 动态生成的归一化拉普拉斯矩阵 [B, N, N]
"""
batch_size = X.size(0)
laplacians = []
# 获取单位矩阵
I = self.eye.to(X.device)
"""生成动态拉普拉斯矩阵"""
batch_size = X.size(0) # 批次大小
laplacians = [] # 存储各批次的拉普拉斯矩阵
I = self.eye.to(X.device) # 移动单位矩阵到目标设备
for b in range(batch_size):
# 使用特征转换层生成动态嵌入调整因子
filt = self.feature_transform(X[b]) # [N, embed_dim]
# 计算节点嵌入向量
nodevec = torch.tanh(self.node_embeddings * filt)
# 通过节点嵌入的点积计算邻接矩阵
adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1)))
# 计算归一化拉普拉斯矩阵
laplacian = self.get_laplacian(adj, I)
filt = self.feature_transform(X[b]) # 特征转换
nodevec = torch.tanh(self.node_embeddings * filt) # 计算节点嵌入
adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1))) # 计算邻接矩阵
laplacian = self.get_laplacian(adj, I) # 计算拉普拉斯矩阵
laplacians.append(laplacian)
return torch.stack(laplacians, dim=0)
return torch.stack(laplacians, dim=0) # 堆叠并返回
class GraphEnhancedEncoder(nn.Module):
"""
@ -190,8 +164,8 @@ class ASTRA(nn.Module):
# 添加动态图增强编码器
self.graph_encoder = GraphEnhancedEncoder(
K=configs.get('chebyshev_order', 3),
in_dim=self.d_model,
hidden_dim=configs.get('graph_hidden_dim', 32),
in_dim=self.d_model * self.input_dim,
hidden_dim=self.d_model,
num_nodes=self.num_nodes,
embed_dim=configs.get('graph_embed_dim', 10),
device=self.device
@ -199,14 +173,14 @@ class ASTRA(nn.Module):
# 特征融合层
self.feature_fusion = nn.Linear(
self.d_model + configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1),
self.d_model * self.input_dim + self.d_model * (configs.get('chebyshev_order', 3) + 1),
self.d_model
)
self.out_mlp = nn.Sequential(
nn.Linear(self.d_llm, 128),
nn.ReLU(),
nn.Linear(128, self.pred_len)
nn.Linear(128, self.pred_len * self.output_dim)
)
for i, (name, param) in enumerate(self.gpts.named_parameters()):
@ -229,9 +203,9 @@ class ASTRA(nn.Module):
x = x[..., :self.input_dim]
x_enc = rearrange(x, 'b t n c -> b n c t')
# 原版Patch
enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, C)
enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, d_model * input_dim)
# 应用图增强编码器(自动生成图结构)
graph_enhanced = self.graph_encoder(enc_out)
graph_enhanced = self.graph_encoder(enc_out) # (B, N, K * hidden_dim)
# 特征融合 - 现在两个张量都是三维的 [B, N, d_model]
enc_out = torch.cat([enc_out, graph_enhanced], dim=-1)
enc_out = self.feature_fusion(enc_out)
@ -243,9 +217,10 @@ class ASTRA(nn.Module):
enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state
dec_out = self.out_mlp(enc_out)
outputs = dec_out.unsqueeze(dim=-1)
outputs = outputs.repeat(1, 1, 1, n_vars)
outputs = outputs.permute(0,2,1,3)
dec_out = self.out_mlp(enc_out) #[B, N, T*C]
B, N, _ = dec_out.shape
outputs = dec_out.view(B, N, self.pred_len, self.output_dim)
outputs = outputs.permute(0, 2, 1, 3) # B, T, N, C
return outputs

View File

@ -128,6 +128,7 @@ class ASTRA(nn.Module):
self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度
self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径
self.num_nodes = configs.get('num_nodes', 325) # 节点数量
self.output_dim = configs.get('output_dim', 1)
self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层
@ -169,7 +170,7 @@ class ASTRA(nn.Module):
self.out_mlp = nn.Sequential(
nn.Linear(self.d_llm, 128),
nn.ReLU(),
nn.Linear(128, self.pred_len)
nn.Linear(128, self.pred_len * self.output_dim)
)
# 设置参数可训练性 wps=word position embeddings
@ -202,9 +203,8 @@ class ASTRA(nn.Module):
dec_out = self.out_mlp(enc_out) # [B,N,pred_len]
# 维度调整
dec_out = self.out_mlp(enc_out)
outputs = dec_out.unsqueeze(dim=-1)
outputs = outputs.repeat(1, 1, 1, self.input_dim)
outputs = outputs.permute(0,2,1,3)
B, N, _ = dec_out.shape
outputs = dec_out.view(B, N, self.pred_len, self.output_dim)
outputs = outputs.permute(0, 2, 1, 3) # B, T, N, C
return outputs

View File

@ -128,6 +128,7 @@ class ASTRA(nn.Module):
self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度
self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径
self.num_nodes = configs.get('num_nodes', 325) # 节点数量
self.output_dim = configs.get('output_dim', 1)
self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层
@ -169,7 +170,7 @@ class ASTRA(nn.Module):
self.out_mlp = nn.Sequential(
nn.Linear(self.d_llm, 128),
nn.ReLU(),
nn.Linear(128, self.pred_len)
nn.Linear(128, self.pred_len * self.output_dim)
)
# 设置参数可训练性 wps=word position embeddings
@ -203,9 +204,8 @@ class ASTRA(nn.Module):
dec_out = self.out_mlp(X_enc) # [B,N,pred_len]
# 维度调整
dec_out = self.out_mlp(enc_out)
outputs = dec_out.unsqueeze(dim=-1)
outputs = outputs.repeat(1, 1, 1, self.input_dim)
outputs = outputs.permute(0,2,1,3)
B, N, _ = dec_out.shape
outputs = dec_out.view(B, N, self.pred_len, self.output_dim)
outputs = outputs.permute(0, 2, 1, 3) # B, T, N, C
return outputs