# TrafficWheel/model/EXP/trash/EXP18.py
#
# NOTE: the hosting page flagged this file as containing Unicode characters
# that might be confused with other characters.


import torch
import torch.nn as nn
import torch.nn.functional as F
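"""
Frequency-domain version of the two-layer sandwich model:
1. Apply a Fourier transform first -> run the sandwich structure
   (temporal - spatial - temporal) in the frequency domain.
2. After processing, map back to the time domain -> output the prediction.
"""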
class DynamicGraphConstructor(nn.Module):
    """Build a learnable dense adjacency matrix from two node-embedding tables.

    The adjacency is ``softmax(relu(nodevec1 @ nodevec2.T), dim=-1)``; it does
    not depend on any input, only on the two learned parameters.

    Args:
        node_num: Number of graph nodes (rows of each embedding table).
        embed_dim: Width of each node embedding.
    """

    def __init__(self, node_num: int, embed_dim: int) -> None:
        super().__init__()
        # Both tables are drawn from a standard normal; nodevec1 is created
        # first so the RNG consumption order stays stable.
        self.nodevec1 = nn.Parameter(torch.randn(node_num, embed_dim), requires_grad=True)
        self.nodevec2 = nn.Parameter(torch.randn(node_num, embed_dim), requires_grad=True)

    def forward(self) -> torch.Tensor:
        """Return the ``(node_num, node_num)`` adjacency; each row sums to 1."""
        scores = torch.matmul(self.nodevec1, self.nodevec2.T)
        # ReLU clips negative affinities to 0 before the row-wise softmax, so
        # clipped entries still receive a (small) non-zero weight of exp(0).
        scores = F.relu(scores)
        return F.softmax(scores, dim=-1)
class GraphConvBlock(nn.Module):
    """Single graph-convolution layer with a residual connection.

    Computes ``relu(theta(adj @ x) + shortcut(x))`` where ``shortcut`` is the
    identity when ``input_dim == output_dim`` and a linear projection
    otherwise.

    Args:
        input_dim: Size of the last axis of the input features.
        output_dim: Size of the last axis of the output features.
    """

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.theta = nn.Linear(input_dim, output_dim)
        self.residual = (input_dim == output_dim)
        # The projection only exists when the shapes differ, so the set of
        # parameters (and state-dict keys) depends on the two dimensions.
        if not self.residual:
            self.res_proj = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Apply the layer.

        Args:
            x: Node features; the in-file comment gives the layout (B, N, C).
            adj: Matrix left-multiplied onto ``x`` via ``torch.matmul``.

        Returns:
            Tensor with the same leading axes as ``x`` and ``output_dim`` as
            the last axis.
        """
        shortcut = x if self.residual else self.res_proj(x)
        aggregated = torch.matmul(adj, x)
        return F.relu(self.theta(aggregated) + shortcut)
class MANBA_Block(nn.Module):
    """Self-attention block followed by a feed-forward network.

    Each of the two sub-layers is wrapped in a residual connection followed by
    ``LayerNorm``. The attention module is built with ``batch_first=True``, so
    attention runs along dim 1 of the input.

    Args:
        input_dim: Feature width of the input and output (attention embed dim).
        hidden_dim: Inner width of the feed-forward network.
        num_heads: Number of attention heads. Defaults to 4, the value that was
            previously hard-coded. ``input_dim`` must be divisible by it
            (enforced by ``nn.MultiheadAttention``).
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_heads: int = 4) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=input_dim, num_heads=num_heads, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention and feed-forward sub-layers.

        Args:
            x: Input tensor; the in-file comment gives the layout (B, N, C).

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Self-attention: query, key and value are all the input itself.
        attended, _ = self.attn(x, x, x)
        x = self.norm1(x + attended)
        # Position-wise feed-forward with its own residual + normalisation.
        return self.norm2(x + self.ffn(x))
class SandwichBlock(nn.Module):
    """Attention -> graph convolution -> attention stack.

    The original description calls this a "temporal - spatial - temporal"
    structure. Note that both attention stages receive a (B, N, C) tensor and
    are batch-first, so they attend along dim 1 (the N axis).

    Input / output layout: (B, N, C) with ``C == hidden_dim``.

    Args:
        num_nodes: Number of graph nodes; size of the learned adjacency.
        embed_dim: Node-embedding width used to build the adjacency.
        hidden_dim: Feature width carried through all three stages.
    """

    def __init__(self, num_nodes: int, embed_dim: int, hidden_dim: int) -> None:
        super().__init__()
        # Construction order is kept (manba1, graph_constructor, gc, manba2)
        # so parameter initialisation consumes the RNG in the same sequence.
        self.manba1 = MANBA_Block(hidden_dim, hidden_dim * 2)
        self.graph_constructor = DynamicGraphConstructor(num_nodes, embed_dim)
        self.gc = GraphConvBlock(hidden_dim, hidden_dim)
        self.manba2 = MANBA_Block(hidden_dim, hidden_dim * 2)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Run the three stages in sequence and return a tensor shaped like ``h``."""
        h = self.manba1(h)
        # The adjacency is rebuilt from the learned embeddings on every call.
        adj = self.graph_constructor()
        h = self.gc(h, adj)
        return self.manba2(h)
class EXP(nn.Module):
    """Forecasting model that processes the input in the frequency domain.

    Pipeline: real FFT along the time axis -> concatenate real and imaginary
    parts -> linear projection to ``hidden_dim`` -> two ``SandwichBlock``
    stages (with a residual around the first) -> linear projection to
    ``horizon * output_dim``. No inverse FFT is applied; the final linear
    layer produces the output directly.

    Args:
        args: Mapping of hyper-parameters.
            Required keys: ``'horizon'``, ``'output_dim'``, ``'num_nodes'``.
            Optional keys: ``'in_len'`` (default 12), ``'hidden_dim'``
            (default 64), ``'embed_dim'`` (default 16).
    """

    def __init__(self, args):
        super().__init__()
        self.horizon = args['horizon']
        self.output_dim = args['output_dim']
        self.seq_len = args.get('in_len', 12)
        self.hidden_dim = args.get('hidden_dim', 64)
        self.num_nodes = args['num_nodes']
        self.embed_dim = args.get('embed_dim', 16)
        # Number of frequency bins returned by rfft for a length-seq_len signal.
        self.freq_len = self.seq_len // 2 + 1

        # Real and imaginary parts are concatenated, hence 2 * freq_len inputs.
        self.freq_proj = nn.Linear(self.freq_len * 2, self.hidden_dim)

        # Two stacked sandwich blocks operating on the projected spectrum.
        self.sandwich1 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        self.sandwich2 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)

        # Maps hidden features straight to the flattened forecast.
        self.out_proj = nn.Linear(self.hidden_dim, self.horizon * self.output_dim)

    def forward(self, x):
        """Compute the forecast.

        Args:
            x: Tensor of shape (B, T, N, D_total). Only channel 0 of the last
                axis is used. ``T`` must equal ``seq_len`` and ``N`` must equal
                ``num_nodes``.

        Returns:
            Tensor of shape (B, horizon, N, output_dim).

        Raises:
            ValueError: If ``x`` is not 4-D, or its time / node axis does not
                match the configured ``seq_len`` / ``num_nodes``.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D input (B, T, N, D), got {x.dim()} dimension(s)"
            )

        x_main = x[..., 0]  # (B, T, N)
        B, T, N = x_main.shape
        # Explicit checks instead of `assert`, which is removed under `python -O`.
        if T != self.seq_len:
            raise ValueError(
                f"input sequence length {T} does not match seq_len={self.seq_len}"
            )
        if N != self.num_nodes:
            raise ValueError(
                f"input node count {N} does not match num_nodes={self.num_nodes}"
            )

        # Real FFT of every node's time series.
        Xf = torch.fft.rfft(x_main, dim=1)  # (B, F, N), complex

        # Split into real / imaginary parts and stack them on the feature axis.
        real = Xf.real.permute(0, 2, 1)  # (B, N, F)
        imag = Xf.imag.permute(0, 2, 1)  # (B, N, F)
        freq_input = torch.cat([real, imag], dim=-1)  # (B, N, 2F)

        h = self.freq_proj(freq_input)  # (B, N, hidden_dim)

        # First sandwich block with a residual connection, then the second.
        h1 = self.sandwich1(h) + h
        h2 = self.sandwich2(h1)

        # Project to the flattened forecast and unflatten (horizon, output_dim).
        out = self.out_proj(h2)  # (B, N, horizon * output_dim)
        out = out.reshape(B, N, self.horizon, self.output_dim)
        return out.permute(0, 2, 1, 3)  # (B, horizon, N, output_dim)