TrafficWheel/model/EXP/trash/EXP2.py

91 lines
3.3 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to a sequence-first Transformer input.

    A fixed table of shape ``(max_len, d_model)`` is precomputed with sine
    values in the even feature columns and cosine values in the odd feature
    columns. The table is registered as a buffer, so it is not a trainable
    parameter.

    Args:
        d_model: Size of the feature dimension. Odd values are supported; the
            cosine half then simply has one column fewer than the sine half.
        max_len: Number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)  # even feature columns
        # When d_model is odd there are only d_model // 2 odd columns, one
        # fewer than the number of frequencies, so trim the angles to match.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])  # odd feature columns
        self.register_buffer("pe", pe)  # buffer: excluded from gradient updates

    def forward(self, x):
        """Return ``x`` with the positional encodings of its first axis added.

        Args:
            x: Tensor laid out as ``(T, B, d_model)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        seq_len = x.size(0)
        # (T, d_model) -> (T, 1, d_model), broadcast over the batch axis.
        return x + self.pe[:seq_len].unsqueeze(1)
class EXP(nn.Module):
    """Per-node Transformer encoder for multi-step forecasting.

    Only channel 0 of the input is used. Each node's length-T history is
    treated as an independent sequence and passed through a Transformer
    encoder. The encoder output at the last time step is mapped by a linear
    layer to ``horizon * output_dim`` values, which are reshaped to
    ``(B, horizon, N, output_dim)``.

    Args:
        args: Hyperparameter mapping. ``horizon`` and ``output_dim`` are
            required. Optional keys and their defaults: ``in_len`` (12),
            ``d_model`` (64), ``nhead`` (4), ``num_layers`` (2),
            ``dim_feedforward`` (``4 * d_model``), ``dropout`` (0.1).
    """

    def __init__(self, args):
        super().__init__()

        # Forecast configuration.
        self.horizon = args["horizon"]
        self.output_dim = args["output_dim"]
        self.seq_len = args.get("in_len", 12)
        assert self.seq_len is not None, "请在 args 中指定 in_len输入序列长度"

        # Encoder hyperparameters.
        d_model = args.get("d_model", 64)
        nhead = args.get("nhead", 4)
        num_layers = args.get("num_layers", 2)
        dim_ff = args.get("dim_feedforward", d_model * 4)
        dropout = args.get("dropout", 0.1)

        # Lift the single input channel to the model width.
        self.input_proj = nn.Linear(1, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len=self.seq_len)

        # The encoder runs sequence-first, i.e. on (T, B, D) tensors.
        layer = nn.TransformerEncoderLayer(
            d_model,
            nhead,
            dim_feedforward=dim_ff,
            dropout=dropout,
            batch_first=False,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers)

        # Maps the last-step representation to all forecast steps at once.
        self.decoder = nn.Linear(d_model, self.horizon * self.output_dim)

    def forward(self, x):
        """Forecast ``horizon`` steps for every node.

        Args:
            x: Input laid out as ``(B, T, N, D_total)``; only ``x[..., 0]``
                is read.

        Returns:
            Tensor of shape ``(B, horizon, N, output_dim)``.

        Raises:
            AssertionError: If ``T`` differs from the configured ``in_len``.
        """
        series = x[..., 0]  # (B, T, N)
        B, T, N = series.shape
        assert T == self.seq_len, f"实际序列长度 {T} != 配置 in_len {self.seq_len}"

        # Fold the node axis into the batch axis: (B, T, N) -> (B*N, T, 1).
        per_node = series.transpose(1, 2).reshape(B * N, T, 1)

        # Embed, switch to sequence-first layout, then add position information.
        tokens = self.input_proj(per_node).transpose(0, 1)  # (T, B*N, d_model)
        tokens = self.pos_encoder(tokens)

        encoded = self.transformer(tokens)  # (T, B*N, d_model)
        final_state = encoded[-1]  # (B*N, d_model), last time step

        flat = self.decoder(final_state)  # (B*N, horizon * output_dim)

        # (B, N, horizon, output_dim) -> (B, horizon, N, output_dim)
        return flat.view(B, N, self.horizon, self.output_dim).transpose(1, 2)