TrafficWheel/model/EXP/trash/EXP17.py

171 lines
5.3 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
"""
基于傅里叶变换优化的双层三明治结构模型
新增TemporalFourierBlock 用于全局捕捉时序频域特征,提升预测精度
第一层Fourier 时域 -> 空间 -> 时间
残差连接:层输出 + 层输入
第二层:同样三明治结构 -> 最终输出
"""
class TemporalFourierBlock(nn.Module):
    """Learnable per-frequency filter applied along the time axis.

    The input is moved to the frequency domain with a real FFT over
    dimension 1, every frequency bin is multiplied by its own learnable real
    coefficient, and the result is transformed back to the time domain.

    Input:  x, expected shape (B, T, N) with T equal to ``seq_len``.
    Output: time-domain reconstruction of shape (B, T, N).
    """

    def __init__(self, seq_len):
        super().__init__()
        self.seq_len = seq_len
        # A real FFT over seq_len samples produces seq_len // 2 + 1 bins;
        # one randomly initialised gain is learned for each bin.
        num_bins = seq_len // 2 + 1
        self.scale = nn.Parameter(torch.randn(num_bins), requires_grad=True)

    def forward(self, x):
        """Filter ``x`` in the frequency domain and return it in the time domain."""
        spectrum = torch.fft.rfft(x, dim=1)  # (B, F, N), complex
        # A real-valued gain scales the real and imaginary parts together.
        gains = self.scale.view(1, -1, 1)
        filtered = spectrum * gains
        # Passing n explicitly restores the original length for odd seq_len too.
        return torch.fft.irfft(filtered, n=self.seq_len, dim=1)
class DynamicGraphConstructor(nn.Module):
    """Produce a dense adjacency matrix from two learnable node-embedding tables.

    Each table has shape (node_num, embed_dim). The adjacency is the
    row-wise softmax of the rectified inner products between the rows of
    the first table and the rows of the second, so it has shape
    (node_num, node_num) and every row sums to 1.
    """

    def __init__(self, node_num, embed_dim):
        super().__init__()
        table_shape = (node_num, embed_dim)
        # nodevec1 is created before nodevec2 so seeded initialisation
        # draws the random values in the same order.
        self.nodevec1 = nn.Parameter(torch.randn(*table_shape), requires_grad=True)
        self.nodevec2 = nn.Parameter(torch.randn(*table_shape), requires_grad=True)

    def forward(self):
        """Return the (node_num, node_num) row-normalised adjacency matrix."""
        # Negative affinities are clamped to zero before normalisation.
        affinity = F.relu(self.nodevec1 @ self.nodevec2.t())
        return F.softmax(affinity, dim=-1)
class GraphConvBlock(nn.Module):
    """Single graph-convolution layer with a residual shortcut and ReLU.

    Node features are aggregated with the adjacency matrix, passed through
    a linear map, added to a shortcut of the input, and rectified. The
    shortcut is the input itself when the feature width is unchanged;
    otherwise it is a separate linear projection to the output width.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.theta = nn.Linear(input_dim, output_dim)
        # True when the input can be added to the output without projection.
        self.residual = input_dim == output_dim
        if not self.residual:
            self.res_proj = nn.Linear(input_dim, output_dim)

    def forward(self, x, adj):
        """Apply the layer.

        x:   node features, expected shape (B, N, C).
        adj: adjacency matrix, expected shape (N, N).
        """
        if self.residual:
            shortcut = x
        else:
            shortcut = self.res_proj(x)
        aggregated = torch.matmul(adj, x)
        return F.relu(self.theta(aggregated) + shortcut)
class MANBA_Block(nn.Module):
    """Self-attention block followed by a feed-forward network.

    Both sub-layers are wrapped as ``LayerNorm(input + sublayer(input))``.
    The output has the same shape as the input.

    Args:
        input_dim: feature width of the input; also the attention embedding
            width. ``nn.MultiheadAttention`` requires it to be divisible by
            ``num_heads``.
        hidden_dim: width of the hidden layer of the feed-forward network.
        num_heads: number of attention heads. Defaults to 4, the value that
            was previously hard-coded, so existing callers are unaffected.
    """

    def __init__(self, input_dim, hidden_dim, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=input_dim, num_heads=num_heads, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        """Apply attention and feed-forward sub-layers to ``x``.

        x: expected shape (B, N, C); because the attention layer is built
        with ``batch_first=True``, dimension 1 (N) is treated as the
        sequence axis that attention mixes over.
        """
        # Self-attention: query, key and value are all the input itself.
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)
        # Feed-forward sub-layer with its own residual connection.
        x = self.norm2(x + self.ffn(x))
        return x
class SandwichBlock(nn.Module):
    """Attention -> graph convolution -> attention "sandwich" block.

    A MANBA_Block is applied first, then a GraphConvBlock that uses the
    adjacency produced by this block's own DynamicGraphConstructor, then a
    second MANBA_Block.

    Input / output: expected shape (B, N, hidden_dim).
    """

    def __init__(self, num_nodes, embed_dim, hidden_dim):
        super().__init__()
        # Both attention blocks use a feed-forward layer twice as wide
        # as the model width.
        ffn_dim = hidden_dim * 2
        self.manba1 = MANBA_Block(hidden_dim, ffn_dim)
        self.graph_constructor = DynamicGraphConstructor(num_nodes, embed_dim)
        self.gc = GraphConvBlock(hidden_dim, hidden_dim)
        self.manba2 = MANBA_Block(hidden_dim, ffn_dim)

    def forward(self, h):
        """Run the three stages in order on ``h`` and return the result."""
        # The adjacency is recomputed from the learnable embeddings on every call.
        spatial = self.gc(self.manba1(h), self.graph_constructor())
        return self.manba2(spatial)
class EXP(nn.Module):
    """Fourier-enhanced forecasting model built from two sandwich blocks.

    Pipeline: the first input feature is filtered by a TemporalFourierBlock
    and added back to itself, each node's length-``in_len`` series is
    projected to ``hidden_dim``, two SandwichBlocks are applied (with a
    residual connection around the first), and a linear head produces
    ``horizon * output_dim`` values per node.

    Args:
        args: dict-like configuration. Required keys: ``"horizon"``,
            ``"output_dim"``, ``"num_nodes"``. Optional keys with defaults:
            ``"in_len"`` (12), ``"hidden_dim"`` (64), ``"embed_dim"`` (16).
    """

    def __init__(self, args):
        super().__init__()
        self.horizon = args["horizon"]
        self.output_dim = args["output_dim"]
        self.seq_len = args.get("in_len", 12)
        self.hidden_dim = args.get("hidden_dim", 64)
        self.num_nodes = args["num_nodes"]
        self.embed_dim = args.get("embed_dim", 16)
        # Temporal Fourier block.
        self.fourier_block = TemporalFourierBlock(self.seq_len)
        # Input projection: (B*N, T) -> hidden_dim.
        self.input_proj = nn.Linear(self.seq_len, self.hidden_dim)
        # Two sandwich blocks.
        self.sandwich1 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        self.sandwich2 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        # Output projection.
        self.out_proj = nn.Linear(self.hidden_dim, self.horizon * self.output_dim)

    def forward(self, x):
        """Forecast from ``x``.

        x: expected shape (B, T, N, D_total); only feature index 0 of the
        last dimension is used.

        Returns a tensor of shape (B, horizon, N, output_dim).

        Raises:
            ValueError: if the time dimension of ``x`` differs from the
                configured input length.
        """
        x_main = x[..., 0]  # (B, T, N)
        B, T, N = x_main.shape
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if T != self.seq_len:
            raise ValueError(
                f"expected input sequence length {self.seq_len}, got {T}"
            )
        # Temporal Fourier filtering plus residual.
        x_freq = self.fourier_block(x_main)  # (B, T, N)
        x_main = x_main + x_freq
        # Input projection: (B, T, N) -> (B*N, T) -> (B, N, hidden_dim).
        x_flat = x_main.permute(0, 2, 1).reshape(B * N, T)
        h0 = self.input_proj(x_flat).view(B, N, self.hidden_dim)
        # First sandwich block plus residual.
        h1 = self.sandwich1(h0)
        h1 = h1 + h0
        # Second sandwich block (no residual).
        h2 = self.sandwich2(h1)
        # Output projection.
        out = self.out_proj(h2)  # (B, N, horizon * output_dim)
        out = out.view(B, N, self.horizon, self.output_dim)
        out = out.permute(0, 2, 1, 3)  # (B, horizon, N, output_dim)
        return out