import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding; adds position information to Transformer inputs."""

    def __init__(self, d_model, max_len=500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)  # (max_len, d_model); assumes d_model is even
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        self.register_buffer("pe", pe)  # buffer: saved with the module, excluded from gradients

    def forward(self, x):
        # x: (T, B, d_model)
        T = x.size(0)
        return x + self.pe[:T].unsqueeze(1)  # (T, 1, d_model) broadcasts to (T, B, d_model)


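# For reference, PositionalEncoding above implements the standard sinusoidal
# scheme from "Attention Is All You Need":
#   pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
#   pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
# Minimal shape-check sketch (illustrative only, not part of the model):
#   enc = PositionalEncoding(d_model=8, max_len=16)
#   out = enc(torch.zeros(4, 2, 8))   # (T, B, d_model)
#   assert out.shape == (4, 2, 8)     # positions added in-place, shape preserved

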
class EXP(nn.Module):
    """
    Transformer-based multi-step forecasting:
    - uses only x[..., 0] as the input channel
    - applies a Transformer encoder to each node's length-T sequence in parallel
    - takes the output at the last time step and maps it through a Linear layer to horizon * output_dim
    - reshapes the result to (B, horizon, N, output_dim)
    """

    def __init__(self, args):
        super().__init__()
        self.horizon = args["horizon"]
        self.output_dim = args["output_dim"]
        self.seq_len = args.get("in_len", 12)
        # defensive check; args.get above already defaults to 12
        assert self.seq_len is not None, "please set in_len (input sequence length) in args"
        d_model = args.get("d_model", 64)
        nhead = args.get("nhead", 4)
        num_layers = args.get("num_layers", 2)
        dim_ff = args.get("dim_feedforward", d_model * 4)
        dropout = args.get("dropout", 0.1)

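        # Example args dict (illustrative values, not from any shipped config):
        #   {"horizon": 12, "output_dim": 1, "in_len": 12,
        #    "d_model": 64, "nhead": 4, "num_layers": 2,
        #    "dim_feedforward": 256, "dropout": 0.1}
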
        # project the single input channel up to d_model
        self.input_proj = nn.Linear(1, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len=self.seq_len)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_ff,
            dropout=dropout,
            batch_first=False,  # we use the (T, B, D) layout
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # map the last-step hidden state to the multi-step prediction
        self.decoder = nn.Linear(d_model, self.horizon * self.output_dim)

    def forward(self, x):
        # x: (B, T, N, D_total)
        x_main = x[..., 0]  # (B, T, N)
        B, T, N = x_main.shape
        assert T == self.seq_len, f"actual sequence length {T} != configured in_len {self.seq_len}"

        # rearrange: treat each node's sequence as an independent sample
        # (B, T, N) -> (B*N, T, 1)
        seq = x_main.permute(0, 2, 1).reshape(B * N, T, 1)

        # input projection & positional encoding
        emb = self.input_proj(seq)   # (B*N, T, d_model)
        emb = emb.permute(1, 0, 2)   # -> (T, B*N, d_model)
        emb = self.pos_encoder(emb)  # add position information

        # Transformer encoder
        out = self.transformer(emb)  # (T, B*N, d_model)

        # hidden vector at the last time step
        last = out[-1, :, :]  # (B*N, d_model)

        # decode into the multi-step prediction
        pred_flat = self.decoder(last)  # (B*N, horizon * output_dim)

        # reshape back: (B, N, horizon, output_dim) -> (B, horizon, N, output_dim)
        pred = pred_flat.view(B, N, self.horizon, self.output_dim).permute(0, 2, 1, 3)
        return pred
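

# Minimal smoke-test sketch, assuming only the args keys read in __init__ above;
# the shapes and hyperparameter values below are illustrative, not from any config.
if __name__ == "__main__":
    args = {"horizon": 3, "output_dim": 1, "in_len": 12, "d_model": 32, "nhead": 4}
    model = EXP(args)
    x = torch.randn(2, 12, 5, 3)  # (B=2, T=12, N=5, D_total=3); only channel 0 is used
    y = model(x)
    print(y.shape)  # expected: torch.Size([2, 3, 5, 1])
    assert y.shape == (2, 3, 5, 1)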