REPST #3

Merged
czzhangheng merged 42 commits from REPST into main 2025-12-20 16:03:22 +08:00
7 changed files with 43 additions and 65 deletions
Showing only changes of commit b38e4a5da2 - Show all commits

View File

@ -32,6 +32,7 @@ model:
seq_len: 24 seq_len: 24
stride: 7 stride: 7
word_num: 1000 word_num: 1000
output_dim: 6
train: train:
batch_size: 16 batch_size: 16

View File

@ -32,6 +32,7 @@ model:
seq_len: 24 seq_len: 24
stride: 7 stride: 7
word_num: 1000 word_num: 1000
output_dim: 6
train: train:
batch_size: 16 batch_size: 16

View File

@ -32,6 +32,7 @@ model:
seq_len: 24 seq_len: 24
stride: 7 stride: 7
word_num: 1000 word_num: 1000
output_dim: 6
train: train:
batch_size: 16 batch_size: 16

View File

@ -10,7 +10,7 @@ data:
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
input_dim: 137 input_dim: 1
lag: 24 lag: 24
normalizer: std normalizer: std
num_nodes: 137 num_nodes: 137

View File

@ -7,22 +7,15 @@ from model.ASTRA.reprogramming import PatchEmbedding, ReprogrammingLayer
import torch.nn.functional as F import torch.nn.functional as F
class DynamicGraphEnhancer(nn.Module): class DynamicGraphEnhancer(nn.Module):
""" """动态图增强编码器"""
动态图增强器基于节点嵌入自动生成图结构
参考DDGCRN的设计使用节点嵌入和特征信息动态计算邻接矩阵
"""
def __init__(self, num_nodes, in_dim, embed_dim=10): def __init__(self, num_nodes, in_dim, embed_dim=10):
super().__init__() super().__init__()
self.num_nodes = num_nodes self.num_nodes = num_nodes # 节点个数
self.embed_dim = embed_dim self.embed_dim = embed_dim # 节点嵌入维度
# 节点嵌入参数 self.node_embeddings = nn.Parameter(torch.randn(num_nodes, embed_dim), requires_grad=True) # 节点嵌入参数
self.node_embeddings = nn.Parameter(
torch.randn(num_nodes, embed_dim), requires_grad=True
)
# 特征转换层,用于生成动态调整的嵌入 self.feature_transform = nn.Sequential( # 特征转换网络
self.feature_transform = nn.Sequential(
nn.Linear(in_dim, 16), nn.Linear(in_dim, 16),
nn.Sigmoid(), nn.Sigmoid(),
nn.Linear(16, 2), nn.Linear(16, 2),
@ -30,48 +23,29 @@ class DynamicGraphEnhancer(nn.Module):
nn.Linear(2, embed_dim) nn.Linear(2, embed_dim)
) )
# 注册单位矩阵作为固定的支持矩阵 self.register_buffer("eye", torch.eye(num_nodes)) # 注册单位矩阵
self.register_buffer("eye", torch.eye(num_nodes))
def get_laplacian(self, graph, I, normalize=True): def get_laplacian(self, graph, I, normalize=True):
""" D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5)) # 度矩阵的逆平方根
计算归一化拉普拉斯矩阵
"""
# 计算度矩阵的逆平方根
D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))
D_inv[torch.isinf(D_inv)] = 0.0 # 处理零除问题 D_inv[torch.isinf(D_inv)] = 0.0 # 处理零除问题
if normalize: if normalize:
return torch.matmul(torch.matmul(D_inv, graph), D_inv) return torch.matmul(torch.matmul(D_inv, graph), D_inv) # 归一化拉普拉斯矩阵
else: else:
return torch.matmul(torch.matmul(D_inv, graph + I), D_inv) return torch.matmul(torch.matmul(D_inv, graph + I), D_inv) # 带自环的归一化拉普拉斯矩阵
def forward(self, X): def forward(self, X):
""" """生成动态拉普拉斯矩阵"""
X: 输入特征 [B, N, D] batch_size = X.size(0) # 批次大小
返回: 动态生成的归一化拉普拉斯矩阵 [B, N, N] laplacians = [] # 存储各批次的拉普拉斯矩阵
""" I = self.eye.to(X.device) # 移动单位矩阵到目标设备
batch_size = X.size(0)
laplacians = []
# 获取单位矩阵
I = self.eye.to(X.device)
for b in range(batch_size): for b in range(batch_size):
# 使用特征转换层生成动态嵌入调整因子 filt = self.feature_transform(X[b]) # 特征转换
filt = self.feature_transform(X[b]) # [N, embed_dim] nodevec = torch.tanh(self.node_embeddings * filt) # 计算节点嵌入
adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1))) # 计算邻接矩阵
# 计算节点嵌入向量 laplacian = self.get_laplacian(adj, I) # 计算拉普拉斯矩阵
nodevec = torch.tanh(self.node_embeddings * filt)
# 通过节点嵌入的点积计算邻接矩阵
adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1)))
# 计算归一化拉普拉斯矩阵
laplacian = self.get_laplacian(adj, I)
laplacians.append(laplacian) laplacians.append(laplacian)
return torch.stack(laplacians, dim=0) # 堆叠并返回
return torch.stack(laplacians, dim=0)
class GraphEnhancedEncoder(nn.Module): class GraphEnhancedEncoder(nn.Module):
""" """
@ -190,8 +164,8 @@ class ASTRA(nn.Module):
# 添加动态图增强编码器 # 添加动态图增强编码器
self.graph_encoder = GraphEnhancedEncoder( self.graph_encoder = GraphEnhancedEncoder(
K=configs.get('chebyshev_order', 3), K=configs.get('chebyshev_order', 3),
in_dim=self.d_model, in_dim=self.d_model * self.input_dim,
hidden_dim=configs.get('graph_hidden_dim', 32), hidden_dim=self.d_model,
num_nodes=self.num_nodes, num_nodes=self.num_nodes,
embed_dim=configs.get('graph_embed_dim', 10), embed_dim=configs.get('graph_embed_dim', 10),
device=self.device device=self.device
@ -199,14 +173,14 @@ class ASTRA(nn.Module):
# 特征融合层 # 特征融合层
self.feature_fusion = nn.Linear( self.feature_fusion = nn.Linear(
self.d_model + configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1), self.d_model * self.input_dim + self.d_model * (configs.get('chebyshev_order', 3) + 1),
self.d_model self.d_model
) )
self.out_mlp = nn.Sequential( self.out_mlp = nn.Sequential(
nn.Linear(self.d_llm, 128), nn.Linear(self.d_llm, 128),
nn.ReLU(), nn.ReLU(),
nn.Linear(128, self.pred_len) nn.Linear(128, self.pred_len * self.output_dim)
) )
for i, (name, param) in enumerate(self.gpts.named_parameters()): for i, (name, param) in enumerate(self.gpts.named_parameters()):
@ -229,9 +203,9 @@ class ASTRA(nn.Module):
x = x[..., :self.input_dim] x = x[..., :self.input_dim]
x_enc = rearrange(x, 'b t n c -> b n c t') x_enc = rearrange(x, 'b t n c -> b n c t')
# 原版Patch # 原版Patch
enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, C) enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, d_model * input_dim)
# 应用图增强编码器(自动生成图结构) # 应用图增强编码器(自动生成图结构)
graph_enhanced = self.graph_encoder(enc_out) graph_enhanced = self.graph_encoder(enc_out) # (B, N, K * hidden_dim)
# 特征融合 - 现在两个张量都是三维的 [B, N, d_model] # 特征融合 - 现在两个张量都是三维的 [B, N, d_model]
enc_out = torch.cat([enc_out, graph_enhanced], dim=-1) enc_out = torch.cat([enc_out, graph_enhanced], dim=-1)
enc_out = self.feature_fusion(enc_out) enc_out = self.feature_fusion(enc_out)
@ -243,9 +217,10 @@ class ASTRA(nn.Module):
enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings) enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state
dec_out = self.out_mlp(enc_out) dec_out = self.out_mlp(enc_out) #[B, N, T*C]
outputs = dec_out.unsqueeze(dim=-1)
outputs = outputs.repeat(1, 1, 1, n_vars) B, N, _ = dec_out.shape
outputs = outputs.permute(0,2,1,3) outputs = dec_out.view(B, N, self.pred_len, self.output_dim)
outputs = outputs.permute(0, 2, 1, 3) # B, T, N, C
return outputs return outputs

View File

@ -128,6 +128,7 @@ class ASTRA(nn.Module):
self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度 self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度
self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径 self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径
self.num_nodes = configs.get('num_nodes', 325) # 节点数量 self.num_nodes = configs.get('num_nodes', 325) # 节点数量
self.output_dim = configs.get('output_dim', 1)
self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层 self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层
@ -169,7 +170,7 @@ class ASTRA(nn.Module):
self.out_mlp = nn.Sequential( self.out_mlp = nn.Sequential(
nn.Linear(self.d_llm, 128), nn.Linear(self.d_llm, 128),
nn.ReLU(), nn.ReLU(),
nn.Linear(128, self.pred_len) nn.Linear(128, self.pred_len * self.output_dim)
) )
# 设置参数可训练性 wps=word position embeddings # 设置参数可训练性 wps=word position embeddings
@ -202,9 +203,8 @@ class ASTRA(nn.Module):
dec_out = self.out_mlp(enc_out) # [B,N,pred_len] dec_out = self.out_mlp(enc_out) # [B,N,pred_len]
# 维度调整 # 维度调整
dec_out = self.out_mlp(enc_out) B, N, _ = dec_out.shape
outputs = dec_out.unsqueeze(dim=-1) outputs = dec_out.view(B, N, self.pred_len, self.output_dim)
outputs = outputs.repeat(1, 1, 1, self.input_dim) outputs = outputs.permute(0, 2, 1, 3) # B, T, N, C
outputs = outputs.permute(0,2,1,3)
return outputs return outputs

View File

@ -128,6 +128,7 @@ class ASTRA(nn.Module):
self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度 self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度
self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径 self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径
self.num_nodes = configs.get('num_nodes', 325) # 节点数量 self.num_nodes = configs.get('num_nodes', 325) # 节点数量
self.output_dim = configs.get('output_dim', 1)
self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层 self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层
@ -169,7 +170,7 @@ class ASTRA(nn.Module):
self.out_mlp = nn.Sequential( self.out_mlp = nn.Sequential(
nn.Linear(self.d_llm, 128), nn.Linear(self.d_llm, 128),
nn.ReLU(), nn.ReLU(),
nn.Linear(128, self.pred_len) nn.Linear(128, self.pred_len * self.output_dim)
) )
# 设置参数可训练性 wps=word position embeddings # 设置参数可训练性 wps=word position embeddings
@ -203,9 +204,8 @@ class ASTRA(nn.Module):
dec_out = self.out_mlp(X_enc) # [B,N,pred_len] dec_out = self.out_mlp(X_enc) # [B,N,pred_len]
# 维度调整 # 维度调整
dec_out = self.out_mlp(enc_out) B, N, _ = dec_out.shape
outputs = dec_out.unsqueeze(dim=-1) outputs = dec_out.view(B, N, self.pred_len, self.output_dim)
outputs = outputs.repeat(1, 1, 1, self.input_dim) outputs = outputs.permute(0, 2, 1, 3) # B, T, N, C
outputs = outputs.permute(0,2,1,3)
return outputs return outputs