diff --git a/.vscode/launch.json b/.vscode/launch.json index 2b530ca..3dc2b03 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -234,6 +234,14 @@ "console": "integratedTerminal", "args": "--config ./config/AEPSA/v2_SolarEnergy.yaml" }, + { + "name": "AEPSA_v3: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AEPSA/v3_METR-LA.yaml" + }, { "name": "EXPB: NYCBike-InFlow", "type": "debugpy", diff --git a/config/AEPSA/v3_METR-LA.yaml b/config/AEPSA/v3_METR-LA.yaml new file mode 100644 index 0000000..5d22820 --- /dev/null +++ b/config/AEPSA/v3_METR-LA.yaml @@ -0,0 +1,57 @@ +basic: + dataset: METR-LA + device: cuda:0 + mode: train + model: AEPSA_v3 + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 207 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + chebyshev_order: 3 + d_ff: 128 + d_model: 64 + dropout: 0.2 + graph_hidden_dim: 32 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 207 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + word_num: 1000 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 diff --git a/config/AEPSA/v3_PEMS-BAY.yaml b/config/AEPSA/v3_PEMS-BAY.yaml new file mode 100755 index 0000000..9f98483 --- /dev/null +++ b/config/AEPSA/v3_PEMS-BAY.yaml @@ -0,0 +1,54 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: AEPSA_v3 + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_ff: 128 + d_model: 64 + dropout: 0.2 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 325 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + word_num: 1000 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/model/AEPSA/aepsav3.py b/model/AEPSA/aepsav3.py new file mode 100644 index 0000000..6a579b6 --- /dev/null +++ b/model/AEPSA/aepsav3.py @@ -0,0 +1,209 @@ +import torch +import torch.nn as nn +from transformers.models.gpt2.modeling_gpt2 import GPT2Model +from einops import rearrange +from model.AEPSA.normalizer import GumbelSoftmax +from model.AEPSA.reprogramming import ReprogrammingLayer +import torch.nn.functional as F + +# 基于动态图增强的时空序列预测模型实现 + +class DynamicGraphEnhancer(nn.Module): + """动态图增强编码器""" + def __init__(self, num_nodes, in_dim, embed_dim=10): + super().__init__() + self.num_nodes = num_nodes # 节点个数 + self.embed_dim = embed_dim # 节点嵌入维度 + + self.node_embeddings = nn.Parameter(torch.randn(num_nodes, embed_dim), requires_grad=True) # 节点嵌入参数 + + self.feature_transform = nn.Sequential( # 特征转换网络 + nn.Linear(in_dim, 16), + nn.Sigmoid(), + nn.Linear(16, 2), + nn.Sigmoid(), + nn.Linear(2, embed_dim) + ) + + self.register_buffer("eye", torch.eye(num_nodes)) # 注册单位矩阵 + + def get_laplacian(self, graph, I, normalize=True): + D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5)) # 度矩阵的逆平方根 + D_inv[torch.isinf(D_inv)] = 0.0 # 处理零除问题 + if normalize: + return torch.matmul(torch.matmul(D_inv, graph), D_inv) # 归一化拉普拉斯矩阵 + else: + return torch.matmul(torch.matmul(D_inv, graph + I), D_inv) # 带自环的归一化拉普拉斯矩阵 + + def forward(self, X): + """生成动态拉普拉斯矩阵""" + batch_size = X.size(0) # 批次大小 + laplacians = [] # 存储各批次的拉普拉斯矩阵 + I = self.eye.to(X.device) # 移动单位矩阵到目标设备 + + for b in range(batch_size): + filt = self.feature_transform(X[b]) # 特征转换 + nodevec = torch.tanh(self.node_embeddings * filt) # 计算节点嵌入 + adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1))) # 计算邻接矩阵 + laplacian = self.get_laplacian(adj, I) # 计算拉普拉斯矩阵 + laplacians.append(laplacian) + return torch.stack(laplacians, dim=0) # 堆叠并返回 + +class GraphEnhancedEncoder(nn.Module): + """图增强编码器""" + def __init__(self, K=3, in_dim=64, hidden_dim=32, num_nodes=325, embed_dim=10, device='cpu', + temporal_dim=12, num_features=1): + super().__init__() + self.K = K # Chebyshev多项式阶数 + self.in_dim = in_dim # 输入特征维度 + self.hidden_dim = hidden_dim # 隐藏层维度 + self.device = device # 运行设备 + self.temporal_dim = temporal_dim # 时间序列长度 + self.num_features = num_features # 特征通道数量 + + self.input_projection = nn.Sequential( # 输入投影层 + nn.Conv2d(num_features, 16, kernel_size=(1, 3), padding=(0, 1)), + nn.ReLU(), + nn.Conv2d(16, in_dim, kernel_size=(1, temporal_dim)), + nn.ReLU() + ) + + self.graph_enhancer = DynamicGraphEnhancer(num_nodes, in_dim, embed_dim) # 动态图增强器 + self.alpha = nn.Parameter(torch.randn(K + 1, 1)) # 谱系数 + self.W = nn.ParameterList([nn.Parameter(torch.randn(in_dim, hidden_dim)) for _ in range(K + 1)]) # 传播权重 + self.to(device) # 移动到指定设备 + + def chebyshev_polynomials(self, L_tilde, X): + """计算Chebyshev多项式展开""" + T_k_list = [X] # T_0(X) = X + if self.K >= 1: + T_k_list.append(torch.matmul(L_tilde, X)) # T_1(X) = L_tilde * X + for k in range(2, self.K + 1): + T_k_list.append(2 * torch.matmul(L_tilde, T_k_list[-1]) - T_k_list[-2]) # 递推计算 + return T_k_list # 返回多项式列表 + + def forward(self, X): + """输入特征[B,N,C,T],返回增强特征[B,N,hidden_dim*(K+1)]""" + batch_size = X.size(0) # 批次大小 + num_nodes = X.size(1) # 节点数量 + + x = X.permute(0, 2, 1, 3) # [B,C,N,T] + x_proj = self.input_projection(x).squeeze(-1) # [B,in_dim,N] + x_proj = x_proj.permute(0, 2, 1) # [B,N,in_dim] + + enhanced_features = [] # 存储增强特征 + laplacians = self.graph_enhancer(x_proj) # 生成动态拉普拉斯矩阵 + + for b in range(batch_size): + L = laplacians[b] # 当前批次的拉普拉斯矩阵 + + # 特征值缩放 + try: + lambda_max = torch.linalg.eigvalsh(L).max().real # 最大特征值 + lambda_max = 1.0 if lambda_max < 1e-6 else lambda_max # 防止除零 + L_tilde = (2.0 / lambda_max) * L - torch.eye(L.size(0), device=L.device) # 归一化拉普拉斯 + except: + L_tilde = torch.eye(num_nodes, device=X.device) # 异常处理 + + # 计算展开并应用权重 + T_k_list = self.chebyshev_polynomials(L_tilde, x_proj[b]) # 计算Chebyshev多项式 + H_list = [torch.matmul(T_k_list[k], self.W[k]) for k in range(self.K + 1)] # 应用权重 + X_enhanced = torch.cat(H_list, dim=-1) # 拼接特征 + enhanced_features.append(X_enhanced) + + return torch.stack(enhanced_features, dim=0) # 堆叠返回[B,N,hidden_dim*(K+1)],每个节点在每个k阶下的切比雪夫特征 + +class AEPSA(nn.Module): + """自适应特征投影时空自注意力模型""" + def __init__(self, configs): + super(AEPSA, self).__init__() + self.device = configs['device'] # 运行设备 + self.pred_len = configs['pred_len'] # 预测序列长度 + self.seq_len = configs['seq_len'] # 输入序列长度 + self.patch_len = configs['patch_len'] # 补丁长度 + self.input_dim = configs['input_dim'] # 输入特征维度 + self.stride = configs['stride'] # 步长 + self.dropout = configs['dropout'] # Dropout概率 + self.gpt_layers = configs['gpt_layers'] # 使用的GPT2层数 + self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度 + self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径 + self.num_nodes = configs.get('num_nodes', 325) # 节点数量 + + self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层 + + self.d_model = configs['d_model'] # 模型维度 + self.n_heads = configs['n_heads'] # 注意力头数量 + self.d_keys = None # 键维度 + self.d_llm = 768 # GPT2隐藏层维度 + + self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2) # 补丁数量 + self.head_nf = self.d_ff * self.patch_nums # 头特征维度 + + # 初始化GPT2模型 + self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True) # GPT2模型 + self.gpts.h = self.gpts.h[:self.gpt_layers] # 截取指定层数 + self.gpts.apply(self.reset_parameters) # 重置参数 + + self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device) # 词嵌入权重 + self.vocab_size = self.word_embeddings.shape[0] # 词汇表大小 + self.mapping_layer = nn.Linear(self.vocab_size, 1) # 映射层 + self.reprogramming_layer = ReprogrammingLayer(self.d_model + configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1), self.n_heads, self.d_keys, self.d_llm) # 重编程层 + + # 初始化图增强编码器 + self.graph_encoder = GraphEnhancedEncoder( + K=configs.get('chebyshev_order', 3), # Chebyshev多项式阶数 + in_dim=self.d_model, # 输入特征维度 + hidden_dim=configs.get('graph_hidden_dim', 32), # 隐藏层维度 + num_nodes=self.num_nodes, # 节点数量 + embed_dim=configs.get('graph_embed_dim', 10), # 节点嵌入维度 + device=self.device, # 运行设备 + temporal_dim=self.seq_len, # 时间序列长度 + num_features=self.input_dim # 特征通道数 + ) + + self.graph_projection = nn.Linear( # 图特征投影层,每一k阶的切比雪夫权重映射到隐藏维度 + configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1), # 输入维度 + self.d_model # 输出维度 + ) + + self.out_mlp = nn.Sequential( + nn.Linear(self.d_llm, 128), + nn.ReLU(), + nn.Linear(128, self.pred_len) + ) + + # 设置参数可训练性 wps=word position embeddings + for name, param in self.gpts.named_parameters(): + param.requires_grad = 'wpe' in name + + def reset_parameters(self, module): + if hasattr(module, 'weight') and module.weight is not None: + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if hasattr(module, 'bias') and module.bias is not None: + torch.nn.init.zeros_(module.bias) + + def forward(self, x): + # 数据处理 + x = x[..., :1] # [B,T,N,1] + x_enc = rearrange(x, 'b t n c -> b n c t') # [B,N,1,T] + + # 图编码 + H_t = self.graph_encoder(x_enc) # [B,N,1,T] -> [B, N, hidden_dim*(K+1)] + X_t_1 = self.graph_projection(H_t) # [B,N,d_model] + enc_out = torch.cat([H_t, X_t_1], dim = -1) # [B, N, d_model + hidden_dim*(K+1)] + + # 词嵌入处理 + self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0) + masks = self.word_choice(self.mapping_layer.weight.data.permute(1,0)) # [d_llm,1] + source_embeddings = self.word_embeddings[masks==1] # [selected_words,d_llm] + + # 重编程与预测 + enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings) + enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state # [B,N,d_llm] + dec_out = self.out_mlp(enc_out) # [B,N,pred_len] + + # 维度调整 + outputs = dec_out.unsqueeze(dim=-1) # [B,N,pred_len,1] + outputs = outputs.permute(0, 2, 1, 3) # [B,pred_len,N,1] + + return outputs \ No newline at end of file diff --git a/model/model_selector.py b/model/model_selector.py index c669d82..633b02c 100755 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -25,6 +25,7 @@ from model.STAWnet.STAWnet import STAWnet from model.REPST.repst import repst as REPST from model.AEPSA.aepsa import AEPSA as AEPSA from model.AEPSA.aepsav2 import AEPSA as AEPSAv2 +from model.AEPSA.aepsav3 import AEPSA as AEPSAv3 @@ -86,3 +87,5 @@ def model_selector(config): return AEPSA(model_config) case "AEPSA_v2": return AEPSAv2(model_config) + case "AEPSA_v3": + return AEPSAv3(model_config)