138 lines
5.5 KiB
Python
138 lines
5.5 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
|
||
class TemporalBlock(nn.Module):
    """Causal residual block of a temporal convolutional network (TCN).

    Two dilated 1-D convolutions run along the last (time) axis. Each one is
    preceded by zero padding on the LEFT only, so the output at step ``t``
    depends only on inputs at steps ``<= t``, and the output length equals the
    input length ``T``.

    Args:
        in_channels: number of input channels ``C_in``.
        out_channels: number of output channels ``C_out``.
        kernel_size: temporal kernel size of both convolutions.
        dilation: dilation of both convolutions.
        dropout: dropout probability applied after each convolution.

    Shapes:
        ``(batch, C_in, T) -> (batch, C_out, T)``; an unbatched
        ``(C_in, T)`` input gives ``(C_out, T)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation, dropout=0.1):
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        # Left padding needed to keep length T: (kernel_size - 1) * dilation.
        self.padding = (kernel_size - 1) * dilation

        # Causal convolutions: padding is applied manually (left side only) in
        # forward, so the layers themselves use padding=0.
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size,
                               padding=0, dilation=dilation)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size,
                               padding=0, dilation=dilation)

        # 1x1 projection for the residual path when the channel count changes;
        # otherwise the input is added back unchanged.
        self.downsample = (nn.Conv1d(in_channels, out_channels, 1)
                           if in_channels != out_channels else None)

        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        # Normalizes over the channel dimension at each time step.
        self.norm = nn.LayerNorm(out_channels)

    def _causal_conv(self, x, conv):
        """Left-pad ``x`` on the time axis and apply ``conv``.

        The padded length is ``T + padding``; an unpadded conv with this
        kernel/dilation shortens it by ``(kernel_size - 1) * dilation ==
        padding``, so the result has exactly ``T`` time steps.
        """
        return conv(F.pad(x, (self.padding, 0)))  # pad=(left, right)

    def forward(self, x):
        # x: (..., C_in, T)
        out = self.dropout(self.relu(self._causal_conv(x, self.conv1)))  # (..., C_out, T)
        out = self.dropout(self._causal_conv(out, self.conv2))           # (..., C_out, T)

        # Residual branch, projected to C_out if needed.
        res = x if self.downsample is None else self.downsample(x)      # (..., C_out, T)

        # LayerNorm acts on the last axis, so move channels last and back:
        # residual add -> LayerNorm -> ReLU.
        y = self.norm((out + res).transpose(-1, -2))
        return self.relu(y).transpose(-1, -2)
|
||
|
||
|
||
class EXP(nn.Module):
    """Spatio-temporal forecasting model.

    Pipeline:
      1. A TCN (stack of causal ``TemporalBlock``s with dilations 1, 2, 4, ...)
         encodes each node's length-T series of input channel 0.
      2. The TCN hidden state at the last time step is taken for every node,
         giving ``(B, N, hidden_dim)``.
      3. Multi-head self-attention across nodes (nodes are the tokens), with a
         residual connection and LayerNorm, captures spatial dependencies.
      4. A Linear decoder maps each node's feature to ``horizon * output_dim``
         values.

    Args:
        args: dict-like config. Required keys: ``'horizon'``, ``'output_dim'``.
            Optional keys (defaults): ``'in_len'`` (12), ``'hidden_dim'`` (64),
            ``'tcn_layers'`` (3), ``'kernel_size'`` (3), ``'dropout'`` (0.1),
            ``'nhead'`` (4). ``nn.MultiheadAttention`` requires ``hidden_dim``
            to be divisible by ``nhead``.
    """

    def __init__(self, args):
        super().__init__()
        self.seq_len = args.get('in_len', 12)  # expected input length T
        self.horizon = args['horizon']
        self.output_dim = args['output_dim']
        hidden_dim = args.get('hidden_dim', 64)
        tcn_layers = args.get('tcn_layers', 3)
        kernel_size = args.get('kernel_size', 3)
        dropout = args.get('dropout', 0.1)
        nhead = args.get('nhead', 4)

        # ----- Temporal Convolutional Network -----
        # The first block sees 1 channel (only the main observation channel);
        # later blocks map hidden_dim -> hidden_dim. Dilation doubles per layer.
        self.tcn = nn.Sequential(*[
            TemporalBlock(1 if i == 0 else hidden_dim, hidden_dim,
                          kernel_size, 2 ** i, dropout)
            for i in range(tcn_layers)
        ])

        # ----- Spatial Self-Attention -----
        # Nodes are the tokens, each with hidden_dim features. With
        # batch_first=False the module expects (S, B, E) inputs, here S = N.
        self.spatial_attn = nn.MultiheadAttention(embed_dim=hidden_dim,
                                                  num_heads=nhead,
                                                  dropout=dropout,
                                                  batch_first=False)
        # LayerNorm applied after the attention residual connection.
        self.norm_spatial = nn.LayerNorm(hidden_dim)

        # ----- Decoder: hidden_dim -> horizon * output_dim -----
        self.decoder = nn.Linear(hidden_dim, self.horizon * self.output_dim)

    def forward(self, x):
        """Predict ``horizon`` future steps for every node.

        Args:
            x: tensor of shape ``(B, T, N, D_total)``. Only channel 0 of the
                last axis is used. ``T`` must equal the configured ``in_len``.

        Returns:
            Tensor of shape ``(B, horizon, N, output_dim)``.
        """
        B, T, N, _ = x.shape
        assert T == self.seq_len, f"Expected T={self.seq_len}, got {T}"

        # 1) Main channel, (B, T, N) -> (B, N, T) -> (B*N, 1, T). The transpose
        #    is required: reshaping (B, T, N) directly would put values from
        #    different nodes/time steps into the same TCN row.
        x_tcn = x[..., 0].transpose(1, 2).reshape(B * N, 1, T)

        # 2) Causal TCN over time for each node independently.
        tcn_out = self.tcn(x_tcn)  # (B*N, hidden_dim, T)

        # 3) Hidden state at the last time step, regrouped per batch element.
        h = tcn_out[:, :, -1].reshape(B, N, -1)  # (B, N, hidden_dim)

        # 4) Self-attention across nodes; (B, N, E) -> (N, B, E) for batch_first=False.
        h_seq = h.transpose(0, 1)
        attn_out, _ = self.spatial_attn(h_seq, h_seq, h_seq, need_weights=False)
        h_spatial = self.norm_spatial(attn_out.transpose(0, 1) + h)  # residual + LayerNorm

        # 5) Decode each node; Linear acts on the last axis.
        out = self.decoder(h_spatial)  # (B, N, horizon * output_dim)

        # 6) (B, N, horizon, output_dim) -> (B, horizon, N, output_dim)
        return out.reshape(B, N, self.horizon, self.output_dim).permute(0, 2, 1, 3)
|