import torch import torch.nn as nn import torch.nn.functional as F class EXP(nn.Module): """ 高效的多步预测模型: - 输入 x: (B, T, N, D_total),只使用主观测通道 x[...,0] - 对每个节点的序列 x[b,:,n] (长度 T) 通过 shared MLP 编码 - 最后映射到 horizon * output_dim,并重塑为 (B, horizon, N, output_dim) """ def __init__(self, args): super().__init__() self.horizon = args["horizon"] self.output_dim = args["output_dim"] # 隐层维度,可调整 hidden_dim = args.get("hidden_dim", 128) T = 12 self.encoder = nn.Sequential( nn.Linear(in_features=T, out_features=hidden_dim), nn.ReLU(), nn.Dropout(0.1), ) # decoder 将 hidden_dim -> horizon * output_dim self.decoder = nn.Linear(hidden_dim, self.horizon * self.output_dim) def forward(self, x): # x: (B, T, N, D_total) # 1) 只取主观测通道 x_main = x[..., 0] # (B, T, N) B, T, N = x_main.shape # 2) 重排并展开:每个节点的序列当作一个样本 # (B, T, N) -> (B, N, T) -> (B*N, T) h_in = x_main.permute(0, 2, 1).reshape(B * N, T) # 3) shared MLP 编码 h = self.encoder(h_in) # (B*N, hidden_dim) # 4) 解码到所有步预测 out_flat = self.decoder(h) # (B*N, horizon * output_dim) # 5) 重塑回 (B, horizon, N, output_dim) out = out_flat.view(B, N, self.horizon, self.output_dim).permute(0, 2, 1, 3) return out