TrafficWheel/model/EXP/EXP5.py

138 lines
5.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
class TemporalBlock(nn.Module):
    """Causal residual block used in the TCN.

    Two dilated 1-D convolutions are applied to ``x`` of shape
    ``(batch, in_channels, time)``. Each convolution is preceded by
    left-only padding on the time axis, so output step ``t`` depends only
    on input steps ``<= t``. The output has the same time length as the
    input.

    Args:
        in_channels: number of input channels.
        out_channels: number of channels produced by both convolutions.
        kernel_size: convolution kernel width.
        dilation: dilation factor shared by both convolutions.
        dropout: dropout probability applied after each convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation, dropout=0.1):
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        # Left padding that lets an unpadded dilated conv keep length T.
        self.padding = dilation * (kernel_size - 1)

        def _unpadded_conv(c_in):
            # Padding is applied manually (left side only) in forward().
            return nn.Conv1d(c_in, out_channels, kernel_size,
                             padding=0, dilation=dilation)

        # Keep creation order conv1 -> conv2 -> downsample: this fixes the
        # order of random weight initialisation under a given seed.
        self.conv1 = _unpadded_conv(in_channels)
        self.conv2 = _unpadded_conv(out_channels)
        # 1x1 projection on the residual path only when channel counts differ.
        if in_channels == out_channels:
            self.downsample = None
        else:
            self.downsample = nn.Conv1d(in_channels, out_channels, 1)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x):
        """Apply the block.

        Args:
            x: tensor of shape ``(batch, in_channels, time)``.

        Returns:
            Tensor of shape ``(batch, out_channels, time)``.
        """
        steps = x.size(2)
        left_only = (self.padding, 0)  # F.pad order for last dim: (left, right)
        hidden = self.dropout(self.relu(self.conv1(F.pad(x, left_only))))
        hidden = self.dropout(self.conv2(F.pad(hidden, left_only)))
        skip = x if self.downsample is None else self.downsample(x)
        # Each conv already returns `steps` points; the slice is an alignment guard.
        hidden = hidden[..., -steps:]
        # LayerNorm normalises the last axis, so move channels there and back.
        normed = self.norm((hidden + skip).transpose(1, 2))
        return self.relu(normed).transpose(1, 2)
class EXP(nn.Module):
    """Spatio-temporal forecasting model: TCN + spatial self-attention.

    Pipeline:
      1. A stack of causal ``TemporalBlock``s encodes each node's length-T
         series of the primary channel (channel 0 of the input).
      2. The TCN hidden state at the last time step is taken per node and
         regrouped as ``(B, N, hidden_dim)``.
      3. Multi-head self-attention over the node axis (nodes are tokens)
         mixes information across nodes, followed by residual + LayerNorm.
      4. A linear decoder maps each node's feature to
         ``horizon * output_dim`` values.

    Args:
        args: dict-like config. Required keys: ``'horizon'``,
            ``'output_dim'``. Optional keys (default): ``'in_len'`` (12),
            ``'hidden_dim'`` (64), ``'tcn_layers'`` (3), ``'kernel_size'`` (3),
            ``'dropout'`` (0.1), ``'nhead'`` (4; nn.MultiheadAttention
            requires hidden_dim to be divisible by it).
    """

    def __init__(self, args):
        super().__init__()
        self.seq_len = args.get('in_len', 12)  # expected input length T
        self.horizon = args['horizon']
        self.output_dim = args['output_dim']
        hidden_dim = args.get('hidden_dim', 64)
        tcn_layers = args.get('tcn_layers', 3)
        kernel_size = args.get('kernel_size', 3)
        dropout = args.get('dropout', 0.1)
        nhead = args.get('nhead', 4)

        # ----- Temporal Convolutional Network -----
        # Dilation doubles per layer (1, 2, 4, ...). The first block consumes
        # the single primary-observation channel; later blocks consume hidden_dim.
        tcn_blocks = []
        in_ch = 1
        for i in range(tcn_layers):
            dilation = 2 ** i
            out_ch = hidden_dim
            tcn_blocks.append(
                TemporalBlock(in_ch, out_ch, kernel_size, dilation, dropout)
            )
            in_ch = out_ch
        self.tcn = nn.Sequential(*tcn_blocks)

        # ----- Spatial Self-Attention -----
        # Nodes are treated as tokens with feature size hidden_dim.
        # batch_first=False means inputs are laid out as (S, B, E) with S = N.
        self.spatial_attn = nn.MultiheadAttention(embed_dim=hidden_dim,
                                                  num_heads=nhead,
                                                  dropout=dropout,
                                                  batch_first=False)
        self.norm_spatial = nn.LayerNorm(hidden_dim)

        # ----- Decoder -----
        # hidden_dim -> horizon * output_dim
        self.decoder = nn.Linear(hidden_dim, self.horizon * self.output_dim)

    def forward(self, x):
        """Forecast ``horizon`` steps for every node.

        Args:
            x: tensor of shape ``(B, T, N, D_total)``; only channel 0 is used.
                ``T`` must equal the configured ``in_len``.

        Returns:
            Tensor of shape ``(B, horizon, N, output_dim)``.
        """
        B, T, N, _ = x.shape
        assert T == self.seq_len, f"Expected T={self.seq_len}, got {T}"

        # 1) Primary channel, one length-T series per (batch, node).
        #    The node axis must come before time when flattening: reshaping
        #    (B, T, N) straight to (B*N, 1, T) would interleave time steps
        #    of different nodes into each "series".
        x_main = x[..., 0].permute(0, 2, 1)        # (B, N, T)
        x_tcn = x_main.reshape(B * N, 1, T)         # (B*N, 1, T)

        # 2) Temporal encoding.
        tcn_out = self.tcn(x_tcn)                   # (B*N, hidden_dim, T)

        # 3) Hidden state at the last time step, regrouped per batch.
        #    reshape (not view): the slice is a strided, non-contiguous tensor.
        h = tcn_out[:, :, -1].reshape(B, N, -1)     # (B, N, hidden_dim)

        # 4) Self-attention across nodes; the module expects (N, B, E).
        h_seq = h.permute(1, 0, 2)                  # (N, B, hidden_dim)
        attn_out, _ = self.spatial_attn(h_seq, h_seq, h_seq)
        attn_out = attn_out.permute(1, 0, 2)        # (B, N, hidden_dim)
        h_spatial = self.norm_spatial(attn_out + h)  # residual + LayerNorm

        # 5) Per-node decoding to horizon * output_dim values.
        flat = h_spatial.reshape(B * N, -1)         # (B*N, hidden_dim)
        out_flat = self.decoder(flat)               # (B*N, horizon*output_dim)

        # 6) (B*N, horizon*output_dim) -> (B, horizon, N, output_dim)
        out = out_flat.view(B, N, self.horizon, self.output_dim) \
            .permute(0, 2, 1, 3)
        return out