TrafficWheel/model/EXP/EXP3.py

48 lines
1.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
class EXP(nn.Module):
    """Per-node shared-MLP baseline for multi-step forecasting.

    The model reads only the primary observation channel ``x[..., 0]`` and
    treats every node's length-``T`` history as an independent sample: the
    same encoder MLP and linear decoder are applied to each node. The
    decoder's ``horizon * output_dim`` outputs are reshaped into a
    multi-step forecast.

    Args:
        args: Mapping of hyper-parameters. Required keys are ``'horizon'``
            (number of predicted steps) and ``'output_dim'`` (features per
            predicted step). Optional keys are ``'hidden_dim'`` (encoder
            width, default 128) and ``'input_len'`` (length ``T`` of the
            input window, default 12).

    Shapes:
        input:  ``(B, T, N, D_total)`` with ``T == input_len``. Only channel
                0 of the last axis is used.
        output: ``(B, horizon, N, output_dim)``.
    """

    def __init__(self, args):
        super().__init__()
        self.horizon = args['horizon']
        self.output_dim = args['output_dim']
        # Encoder width (tunable).
        hidden_dim = args.get('hidden_dim', 128)
        # The first Linear layer has a fixed fan-in, so the input window
        # length must be known at construction time. It was previously
        # hard-coded to 12; 12 remains the default for backward compatibility.
        self.input_len = args.get('input_len', 12)

        self.encoder = nn.Sequential(
            nn.Linear(in_features=self.input_len, out_features=hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
        )
        # Maps hidden_dim -> horizon * output_dim, so all forecast steps are
        # produced in a single pass.
        self.decoder = nn.Linear(hidden_dim, self.horizon * self.output_dim)

    def forward(self, x):
        """Forecast ``horizon`` steps for every node.

        Args:
            x: Tensor of shape ``(B, T, N, D_total)``.

        Returns:
            Tensor of shape ``(B, horizon, N, output_dim)``. It is a permuted
            view and is therefore not necessarily contiguous.

        Raises:
            ValueError: If ``x`` is not 4-D, or if its time dimension does
                not match the configured ``input_len``.
        """
        if x.dim() != 4:
            raise ValueError(
                f"EXP expects input of shape (B, T, N, D_total); "
                f"got a {x.dim()}-D tensor"
            )
        # 1) Keep only the primary observation channel.
        x_main = x[..., 0]  # (B, T, N)
        B, T, N = x_main.shape
        if T != self.input_len:
            raise ValueError(
                f"EXP was built for input_len={self.input_len} time steps, "
                f"but the input has T={T}"
            )
        # 2) Treat each node's series as one sample:
        #    (B, T, N) -> (B, N, T) -> (B*N, T)
        h_in = x_main.permute(0, 2, 1).reshape(B * N, T)
        # 3) Shared MLP encoding.
        h = self.encoder(h_in)  # (B*N, hidden_dim)
        # 4) Decode all forecast steps at once.
        out_flat = self.decoder(h)  # (B*N, horizon * output_dim)
        # 5) Reshape back to (B, horizon, N, output_dim).
        out = out_flat.view(B, N, self.horizon, self.output_dim).permute(0, 2, 1, 3)
        return out