import math import torch import torch.nn as nn import torch.nn.functional as F class PositionalEncoding(nn.Module): """标准的位置编码,用于给 Transformer 输入添加位置信息""" def __init__(self, d_model, max_len=500): super().__init__() pe = torch.zeros(max_len, d_model) # (max_len, d_model) position = torch.arange(0, max_len).unsqueeze(1).float() # (max_len,1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) # 偶数维 pe[:, 1::2] = torch.cos(position * div_term) # 奇数维 self.register_buffer('pe', pe) # 不参加梯度 def forward(self, x): # x: (T, B, d_model) T = x.size(0) return x + self.pe[:T].unsqueeze(1) # (T,1,d_model) 广播到 (T,B,d_model) class TemporalTransformerForecast(nn.Module): """ Transformer-based 多步预测: - 只使用 x[...,0] 作为输入通道 - 对每个节点的长度-T 序列并行应用 Transformer Encoder - 取最后时间步的输出,通过一个 Linear 映射到 horizon * output_dim - 重塑为 (B, horizon, N, output_dim) """ def __init__(self, args): super().__init__() self.horizon = args['horizon'] self.output_dim = args['output_dim'] self.seq_len = args.get('in_len', 12) assert self.seq_len is not None, "请在 args 中指定 in_len(输入序列长度)" d_model = args.get('d_model', 64) nhead = args.get('nhead', 4) num_layers = args.get('num_layers', 2) dim_ff = args.get('dim_feedforward', d_model * 4) dropout = args.get('dropout', 0.1) # 把单通道投影到 d_model self.input_proj = nn.Linear(1, d_model) self.pos_encoder = PositionalEncoding(d_model, max_len=self.seq_len) encoder_layer = nn.TransformerEncoderLayer( d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, dropout=dropout, batch_first=False # 我们用 (T, B, D) 格式 ) self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) # 最后一步输出到 多步预测 self.decoder = nn.Linear(d_model, self.horizon * self.output_dim) def forward(self, x): # x: (B, T, N, D_total) x_main = x[..., 0] # (B, T, N) B, T, N = x_main.shape assert T == self.seq_len, f"实际序列长度 {T} != 配置 in_len {self.seq_len}" # 重排:每个节点的序列是一个独立样本 # (B, T, N) -> (B*N, T, 1) seq = x_main.permute(0, 2, 1).reshape(B * N, T, 1) # 投影 & 位置编码 emb = self.input_proj(seq) # (B*N, T, d_model) emb = emb.permute(1, 0, 2) # -> (T, B*N, d_model) emb = self.pos_encoder(emb) # 加上位置信息 # Transformer Encoder out = self.transformer(emb) # (T, B*N, d_model) # 取最后时刻的隐藏向量 last = out[-1, :, :] # (B*N, d_model) # 解码为多步预测 pred_flat = self.decoder(last) # (B*N, horizon * output_dim) # 重塑回 (B, N, horizon, output_dim) -> (B, horizon, N, output_dim) pred = pred_flat.view(B, N, self.horizon, self.output_dim) \ .permute(0, 2, 1, 3) return pred