# TrafficWheel/model/EXP/trash/EXP4.py


import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualMLPBlock(nn.Module):
    """
    A residual block at a fixed hidden dimension:
        x -> Linear(hidden -> hidden) -> ReLU -> Dropout
          -> Linear(hidden -> hidden) -> Dropout
          -> + residual skip connection -> LayerNorm
    """

    def __init__(self, hidden_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.drop = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        resid = x
        x = F.relu(self.fc1(x))
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return self.norm(x + resid)
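
# Illustrative usage (hypothetical values, not from the original file): a block
# built with hidden_dim=64 maps (batch, 64) -> (batch, 64), so any number of
# blocks can be chained without reshaping:
#   block = ResidualMLPBlock(hidden_dim=64, dropout=0.1)
#   out = block(torch.randn(8, 64))   # out.shape == (8, 64)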


class EXP(nn.Module):
    """
    Multi-layer MLP forecasting model with residual connections:
      - Input x: (B, T, N, D_total); only x[..., 0] is used.
      - Each length-T (seq_len) sequence is first projected to hidden_dim,
        then passed through num_blocks ResidualMLPBlock layers.
      - Finally projected to horizon * output_dim and reshaped to
        (B, horizon, N, output_dim).
    """

    def __init__(self, args):
        super().__init__()
        self.horizon = args["horizon"]
        self.output_dim = args["output_dim"]
        self.seq_len = args.get("in_len", 12)  # sequence length T, default 12
        hidden_dim = args.get("hidden_dim", 64)
        num_blocks = args.get("num_mlp_layers", 2)
        dropout = args.get("dropout", 0.1)
        # 1) Input projection: T -> hidden_dim
        self.input_proj = nn.Linear(self.seq_len, hidden_dim)
        self.input_drop = nn.Dropout(dropout)
        # 2) Residual MLP blocks
        self.blocks = nn.ModuleList(
            [ResidualMLPBlock(hidden_dim, dropout=dropout) for _ in range(num_blocks)]
        )
        # 3) Output projection: hidden_dim -> horizon * output_dim
        self.decoder = nn.Linear(hidden_dim, self.horizon * self.output_dim)

    def forward(self, x):
        # x: (B, T, N, D_total)
        x_main = x[..., 0]  # (B, T, N)
        B, T, N = x_main.shape
        assert T == self.seq_len, f"expected sequence length {self.seq_len}, got {T}"
        # Treat each node's length-T sequence as an independent sample
        h_in = x_main.permute(0, 2, 1).reshape(B * N, T)  # (B*N, T)
        # 1) Input projection + dropout
        h = F.relu(self.input_proj(h_in))  # (B*N, hidden_dim)
        h = self.input_drop(h)
        # 2) Stack of residual blocks
        for block in self.blocks:
            h = block(h)  # (B*N, hidden_dim)
        # 3) Decode to horizon * output_dim
        out_flat = self.decoder(h)  # (B*N, horizon * output_dim)
        # 4) Reshape to (B, horizon, N, output_dim)
        out = out_flat.view(B, N, self.horizon, self.output_dim).permute(0, 2, 1, 3)
        return out
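

# Minimal smoke test (not part of the original file). Illustrative only: the
# args keys match those read in EXP.__init__ above, but the concrete values
# (B=4, N=207, D_total=3, etc.) are assumptions, not values from the repo.
if __name__ == "__main__":
    args = {
        "horizon": 12,
        "output_dim": 1,
        "in_len": 12,
        "hidden_dim": 64,
        "num_mlp_layers": 2,
        "dropout": 0.1,
    }
    model = EXP(args)
    x = torch.randn(4, 12, 207, 3)  # (B, T, N, D_total)
    y = model(x)
    assert y.shape == (4, 12, 207, 1), y.shape
    print("output shape:", tuple(y.shape))  # -> (4, 12, 207, 1)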