49 lines
1.6 KiB
Python
Executable File
49 lines
1.6 KiB
Python
Executable File
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
|
||
class EXP(nn.Module):
|
||
"""
|
||
高效的多步预测模型:
|
||
- 输入 x: (B, T, N, D_total),只使用主观测通道 x[...,0]
|
||
- 对每个节点的序列 x[b,:,n] (长度 T) 通过 shared MLP 编码
|
||
- 最后映射到 horizon * output_dim,并重塑为 (B, horizon, N, output_dim)
|
||
"""
|
||
|
||
def __init__(self, args):
|
||
super().__init__()
|
||
self.horizon = args["horizon"]
|
||
self.output_dim = args["output_dim"]
|
||
# 隐层维度,可调整
|
||
hidden_dim = args.get("hidden_dim", 128)
|
||
T = 12
|
||
self.encoder = nn.Sequential(
|
||
nn.Linear(in_features=T, out_features=hidden_dim),
|
||
nn.ReLU(),
|
||
nn.Dropout(0.1),
|
||
)
|
||
# decoder 将 hidden_dim -> horizon * output_dim
|
||
self.decoder = nn.Linear(hidden_dim, self.horizon * self.output_dim)
|
||
|
||
def forward(self, x):
|
||
# x: (B, T, N, D_total)
|
||
# 1) 只取主观测通道
|
||
x_main = x[..., 0] # (B, T, N)
|
||
B, T, N = x_main.shape
|
||
|
||
# 2) 重排并展开:每个节点的序列当作一个样本
|
||
# (B, T, N) -> (B, N, T) -> (B*N, T)
|
||
h_in = x_main.permute(0, 2, 1).reshape(B * N, T)
|
||
|
||
# 3) shared MLP 编码
|
||
h = self.encoder(h_in) # (B*N, hidden_dim)
|
||
|
||
# 4) 解码到所有步预测
|
||
out_flat = self.decoder(h) # (B*N, horizon * output_dim)
|
||
|
||
# 5) 重塑回 (B, horizon, N, output_dim)
|
||
out = out_flat.view(B, N, self.horizon, self.output_dim).permute(0, 2, 1, 3)
|
||
|
||
return out
|