import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualMLPBlock(nn.Module):
    """Residual two-layer MLP block operating at a fixed hidden width.

    Computes::

        x -> Linear(hidden->hidden) -> ReLU -> Dropout
          -> Linear(hidden->hidden) -> Dropout
          -> + residual skip -> LayerNorm

    Args:
        hidden_dim: Size of the last dimension of both input and output.
        dropout: Drop probability used for both dropout applications.
    """

    def __init__(self, hidden_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        # One Dropout module is reused after fc1 and after fc2. It holds no
        # state, so each call draws its own mask.
        self.drop = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        """Apply the block.

        Args:
            x: Tensor whose last dimension is ``hidden_dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        resid = x
        x = F.relu(self.fc1(x))
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return self.norm(x + resid)


class EXP(nn.Module):
    """Per-node multi-layer MLP forecaster with residual connections.

    Each node's length-``T`` history of feature channel 0 is treated as an
    independent sample. The history is projected to ``hidden_dim``, passed
    through ``num_mlp_layers`` :class:`ResidualMLPBlock` layers, decoded to
    ``horizon * output_dim`` values, and reshaped into the forecast.

    Input:  ``x`` of shape ``(B, T, N, D_total)``; only ``x[..., 0]`` is used.
    Output: tensor of shape ``(B, horizon, N, output_dim)``.

    Args:
        args: Mapping of hyper-parameters.

            Required keys:
                ``"horizon"`` and ``"output_dim"``.

            Optional keys (default in parentheses):
                ``"in_len"`` (12): the expected sequence length ``T``.
                ``"hidden_dim"`` (64).
                ``"num_mlp_layers"`` (2): the number of residual blocks.
                ``"dropout"`` (0.1).
    """

    def __init__(self, args):
        super().__init__()
        self.horizon = args["horizon"]
        self.output_dim = args["output_dim"]
        self.seq_len = args.get("in_len", 12)  # expected sequence length T
        hidden_dim = args.get("hidden_dim", 64)
        num_blocks = args.get("num_mlp_layers", 2)
        dropout = args.get("dropout", 0.1)

        # 1) Input projection over the time axis: T -> hidden_dim.
        self.input_proj = nn.Linear(self.seq_len, hidden_dim)
        self.input_drop = nn.Dropout(dropout)

        # 2) Stack of residual MLP blocks at width hidden_dim.
        self.blocks = nn.ModuleList(
            [
                ResidualMLPBlock(hidden_dim, dropout=dropout)
                for _ in range(num_blocks)
            ]
        )

        # 3) Output projection: hidden_dim -> horizon * output_dim.
        self.decoder = nn.Linear(hidden_dim, self.horizon * self.output_dim)

    def forward(self, x):
        """Forecast ``horizon`` steps for every node.

        Args:
            x: Tensor of shape ``(B, T, N, D_total)`` with ``T == in_len``.

        Returns:
            Tensor of shape ``(B, horizon, N, output_dim)``.

        Raises:
            ValueError: If ``x`` is not 4-D, or if its time dimension differs
                from the configured ``in_len``.
        """
        if x.dim() != 4:
            raise ValueError(
                "expected input of shape (B, T, N, D_total), "
                f"got {tuple(x.shape)}"
            )
        x_main = x[..., 0]  # (B, T, N)
        B, T, N = x_main.shape
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, which would defer the failure to an opaque matmul error.
        if T != self.seq_len:
            raise ValueError(f"期望序列长度 {self.seq_len}, 实际 {T}")

        # Treat each node's length-T series as an independent sample.
        h_in = x_main.permute(0, 2, 1).reshape(B * N, T)  # (B*N, T)

        # 1) Input projection + ReLU + dropout.
        h = F.relu(self.input_proj(h_in))  # (B*N, hidden_dim)
        h = self.input_drop(h)

        # 2) Residual block stack.
        for block in self.blocks:
            h = block(h)  # (B*N, hidden_dim)

        # 3) Decode to horizon * output_dim values per sample.
        out_flat = self.decoder(h)  # (B*N, horizon * output_dim)

        # 4) Rows are ordered (b, n), so unflatten to
        #    (B, N, horizon, output_dim), then swap the node and horizon axes.
        out = out_flat.view(B, N, self.horizon, self.output_dim).permute(
            0, 2, 1, 3
        )
        return out