import torch
import torch.nn as nn
import torch.nn.functional as F


class TemporalBlock(nn.Module):
    """
    Causal residual block of the TCN. Applies causal convolutions to the
    time series of each node while keeping the output length equal to the
    input length.
    """
    def __init__(self, in_channels, out_channels, kernel_size, dilation, dropout=0.1):
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        # Padding length = (kernel_size - 1) * dilation
        self.padding = (kernel_size - 1) * dilation
        # Causal convolution: pad manually in forward() instead of
        # passing a padding argument here.
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size,
                               padding=0, dilation=dilation)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size,
                               padding=0, dilation=dilation)
        # 1x1 convolution to match channels on the residual branch;
        # identity otherwise.
        self.downsample = (nn.Conv1d(in_channels, out_channels, 1)
                           if in_channels != out_channels else None)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x):
        # x: (B*N, C_in, T)
        # 1) Causal padding: pad only on the left of the time axis.
        x_padded = F.pad(x, (self.padding, 0))            # pad=(left, right)
        # 2) First convolution. The left padding exactly cancels the
        #    (kernel_size - 1) * dilation shrinkage, so length stays T.
        out = self.conv1(x_padded)                        # (B*N, C_out, T)
        out = self.relu(out)
        out = self.dropout(out)
        # 3) Second convolution, again with causal padding.
        out = F.pad(out, (self.padding, 0))
        out = self.conv2(out)                             # (B*N, C_out, T)
        out = self.dropout(out)
        # 4) Residual branch.
        res = x if self.downsample is None else self.downsample(x)  # (B*N, C_out, T)
        # 5) Keep the last T time steps. Given the padding above this is
        #    already the case; retained as a safety measure.
        out = out[..., -x.size(2):]
        # 6) Residual addition + LayerNorm (over channels) + ReLU.
        return self.relu(self.norm((out + res).permute(0, 2, 1))).permute(0, 2, 1)


class EXP(nn.Module):
    """
    Spatio-temporal hybrid model:
    1. For each node's length-T series, extract temporal features with a TCN.
    2. Take the TCN's last-step hidden state and reshape to (B, N, hidden_dim).
    3. Capture spatial dependencies across nodes with spatial self-attention.
    4. A final Linear maps each node's features to horizon-step predictions.
    """
    def __init__(self, args):
        super().__init__()
        self.seq_len = args.get('in_len', 12)       # input sequence length T
        self.horizon = args['horizon']
        self.output_dim = args['output_dim']
        hidden_dim = args.get('hidden_dim', 64)
        tcn_layers = args.get('tcn_layers', 3)
        kernel_size = args.get('kernel_size', 3)
        dropout = args.get('dropout', 0.1)
        nhead = args.get('nhead', 4)

        # ----- Temporal Convolutional Network -----
        tcn_blocks = []
        in_ch = 1   # only the primary observation channel is used
        for i in range(tcn_layers):
            dilation = 2 ** i
            out_ch = hidden_dim
            tcn_blocks.append(
                TemporalBlock(in_ch, out_ch, kernel_size, dilation, dropout)
            )
            in_ch = out_ch
        self.tcn = nn.Sequential(*tcn_blocks)

        # ----- Spatial Self-Attention -----
        # Nodes are treated as tokens with feature dimension hidden_dim.
        # MultiheadAttention expects input (S, B, E); here S = N.
        self.spatial_attn = nn.MultiheadAttention(embed_dim=hidden_dim,
                                                  num_heads=nhead,
                                                  dropout=dropout,
                                                  batch_first=False)
        # Optional LayerNorm
        self.norm_spatial = nn.LayerNorm(hidden_dim)

        # ----- Decoder -----
        # hidden_dim -> horizon * output_dim
        self.decoder = nn.Linear(hidden_dim, self.horizon * self.output_dim)

    def forward(self, x):
        """
        x: (B, T, N, D_total); only channel 0 is used.
        returns: (B, horizon, N, output_dim)
        """
        B, T, N, D_total = x.shape
        assert T == self.seq_len, f"Expected T={self.seq_len}, got {T}"

        # 1) Take the primary observation and rearrange for the TCN.
        #    Permute to (B, N, T) before reshaping: a direct reshape of
        #    (B, T, N) to (B*N, 1, T) would scramble nodes and time steps.
        x_main = x[..., 0]                                      # (B, T, N)
        x_tcn = x_main.permute(0, 2, 1).reshape(B * N, 1, T)    # (B*N, 1, T)

        # 2) TCN temporal feature extraction.
        tcn_out = self.tcn(x_tcn)                   # (B*N, hidden_dim, T)

        # 3) Take the last-step features.
        last = tcn_out[:, :, -1]                    # (B*N, hidden_dim)
        h = last.view(B, N, -1)                     # (B, N, hidden_dim)

        # 4) Spatial attention.
        #    Rearrange to (N, B, E) for MultiheadAttention.
        h2 = h.permute(1, 0, 2)                     # (N, B, hidden_dim)
        attn_out, _ = self.spatial_attn(h2, h2, h2) # (N, B, hidden_dim)
        attn_out = attn_out.permute(1, 0, 2)        # (B, N, hidden_dim)
        h_spatial = self.norm_spatial(attn_out + h) # residual + LayerNorm

        # 5) Decoder: map each node to horizon * output_dim.
        flat = h_spatial.reshape(B * N, -1)         # (B*N, hidden_dim)
        out_flat = self.decoder(flat)               # (B*N, horizon*output_dim)

        # 6) Reshape to (B, horizon, N, output_dim).
        out = out_flat.view(B, N, self.horizon, self.output_dim) \
                      .permute(0, 2, 1, 3)
        return out
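

# Minimal smoke test: a sketch assuming `args` is a plain dict (as the
# .get()/[] access above implies). The batch size, node count, and
# hyperparameter values below are illustrative, not fixed by this file.
if __name__ == "__main__":
    args = {
        'in_len': 12,
        'horizon': 3,
        'output_dim': 1,
        'hidden_dim': 64,
        'tcn_layers': 3,
        'kernel_size': 3,
        'dropout': 0.1,
        'nhead': 4,
    }
    model = EXP(args)
    x = torch.randn(8, 12, 20, 2)   # (B=8, T=12, N=20, D_total=2)
    y = model(x)
    print(y.shape)                  # expected: torch.Size([8, 3, 20, 1])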