TrafficWheel/model/EXP/trash/EXP29.py

218 lines
6.7 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
"""
在原 EXP 模型基础上,添加 Haar 小波变换实现的一层小波去噪,增强时序特征。
"""
class WaveletDenoise(nn.Module):
    """Single-level Haar wavelet denoiser for time series.

    The input is decomposed with the Haar low-pass filter (stride 2), the
    detail coefficients are discarded, and the approximation coefficients are
    mapped back to the original length with a transposed convolution using
    the same filter. The net effect is that every non-overlapping pair of
    consecutive time steps is replaced by the pair's mean.
    """

    def __init__(self):
        super().__init__()
        # Haar low-pass filter [1/sqrt(2), 1/sqrt(2)].
        lp = torch.tensor([1.0, 1.0]) / (2**0.5)
        self.register_buffer("lp_filter", lp.view(1, 1, 2))
        # The reconstruction (transposed-convolution) filter uses the same taps.
        self.register_buffer("lp_rec", lp.view(1, 1, 2))

    def forward(self, x):
        """Denoise ``x`` along its time axis.

        Args:
            x: Tensor of shape (B, T, N).

        Returns:
            Tensor of shape (B, T, N) holding the denoised signal.
        """
        B, T, N = x.shape
        # conv1d expects (batch, channels, length): fold N into the batch axis.
        x_flat = x.permute(0, 2, 1).reshape(B * N, 1, T)
        if T % 2:
            # A stride-2, width-2 filter cannot cover a trailing unpaired
            # sample. Without padding that sample would be dropped and come
            # back as 0 after reconstruction, so pair it with a copy of
            # itself instead (its pair mean is then the sample itself).
            x_flat = F.pad(x_flat, (0, 1), mode="replicate")
        # Decomposition: keep only the approximation coefficients.
        approx = F.conv1d(x_flat, self.lp_filter, stride=2)
        # Reconstruction: each coefficient expands back into two time steps.
        out = F.conv_transpose1d(approx, self.lp_rec, stride=2)
        # Drop the extra step introduced by the padding of odd-length inputs.
        out = out[:, :, :T]
        # Restore the (B, T, N) layout.
        return out.reshape(B, N, T).permute(0, 2, 1)
class DynamicGraphConstructor(nn.Module):
    """Learn a dense adjacency matrix from two node-embedding tables.

    Args:
        node_num: Number of graph nodes (rows of each embedding table).
        embed_dim: Width of each node embedding.
    """

    def __init__(self, node_num, embed_dim):
        super().__init__()
        table_shape = (node_num, embed_dim)
        # Two independently initialised tables; their product gives the
        # pairwise node scores.
        self.nodevec1 = nn.Parameter(torch.randn(*table_shape), requires_grad=True)
        self.nodevec2 = nn.Parameter(torch.randn(*table_shape), requires_grad=True)

    def forward(self):
        """Return the (node_num, node_num) adjacency matrix.

        Scores are clipped at zero with ReLU and then softmax-normalised over
        the last dimension, so every row sums to 1.
        """
        scores = self.nodevec1 @ self.nodevec2.T
        return F.softmax(F.relu(scores), dim=-1)
class GraphConvBlock(nn.Module):
    """One graph-convolution layer with a residual connection.

    Args:
        input_dim: Feature width of the incoming node features.
        output_dim: Feature width produced by the layer.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # An identity shortcut is only possible when the widths agree.
        self.residual = input_dim == output_dim
        self.theta = nn.Linear(input_dim, output_dim)
        if self.residual:
            return
        # Otherwise the shortcut needs its own projection to output_dim.
        self.res_proj = nn.Linear(input_dim, output_dim)

    def forward(self, x, adj):
        """Aggregate neighbours with ``adj``, transform, add shortcut, ReLU.

        Args:
            x: Node features; ``adj`` is matrix-multiplied onto them from the
                left and the last dimension must be ``input_dim``.
            adj: Adjacency matrix applied via ``torch.matmul(adj, x)``.

        Returns:
            Tensor shaped like ``x`` but with last dimension ``output_dim``.
        """
        shortcut = x if self.residual else self.res_proj(x)
        transformed = self.theta(torch.matmul(adj, x))
        return F.relu(transformed + shortcut)
class MANBA_Block(nn.Module):
    """Self-attention block followed by a feed-forward network.

    Both sub-layers are wrapped as ``LayerNorm(input + sublayer(input))``.

    Args:
        input_dim: Feature width of the input; also the attention embedding
            size. The attention uses 4 heads, so it must be divisible by 4.
        hidden_dim: Inner width of the feed-forward network.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=input_dim, num_heads=4, batch_first=True
        )
        widen = nn.Linear(input_dim, hidden_dim)
        narrow = nn.Linear(hidden_dim, input_dim)
        self.ffn = nn.Sequential(widen, nn.ReLU(), narrow)
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        """Apply self-attention and the feed-forward network to ``x``.

        Args:
            x: Batch-first tensor whose last dimension is ``input_dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention: query, key and value are all x; weights are unused.
        attended = self.attn(x, x, x)[0]
        x = self.norm1(x + attended)
        return self.norm2(x + self.ffn(x))
class SandwichBlock(nn.Module):
    """Attention -> graph convolution -> attention stack.

    A graph convolution over a learned adjacency matrix is placed between two
    ``MANBA_Block`` layers.

    Args:
        num_nodes: Number of graph nodes; sizes the learned adjacency.
        embed_dim: Node-embedding width used to build the adjacency.
        hidden_dim: Feature width carried through the whole block.
    """

    def __init__(self, num_nodes, embed_dim, hidden_dim):
        super().__init__()
        ffn_dim = hidden_dim * 2
        self.manba1 = MANBA_Block(hidden_dim, ffn_dim)
        self.graph_constructor = DynamicGraphConstructor(num_nodes, embed_dim)
        self.gc = GraphConvBlock(hidden_dim, hidden_dim)
        self.manba2 = MANBA_Block(hidden_dim, ffn_dim)

    def forward(self, h):
        """Run ``h`` through the three stages and return the result.

        Args:
            h: Hidden state; ``EXP`` passes a (B, N, hidden_dim) tensor.

        Returns:
            Tensor with the same shape as ``h``.
        """
        h = self.manba1(h)
        # The adjacency is rebuilt from the current embeddings on every call.
        h = self.gc(h, self.graph_constructor())
        return self.manba2(h)
class MLP(nn.Module):
    """Plain multi-layer perceptron.

    Every hidden layer is ``Linear`` followed by ``activation()``; the final
    ``Linear`` has no activation.

    Args:
        in_dim: Input feature width.
        hidden_dims: Widths of the hidden layers, in order. Any iterable of
            ints is accepted (list, tuple, ...); it may be empty, in which
            case the network is a single ``Linear(in_dim, out_dim)``.
        out_dim: Output feature width.
        activation: Zero-argument callable returning a fresh activation
            module; called once per hidden layer.
    """

    def __init__(self, in_dim, hidden_dims, out_dim, activation=nn.ReLU):
        super().__init__()
        # Unpacking (rather than list concatenation) lets callers pass a
        # tuple or any other iterable for hidden_dims.
        dims = [in_dim, *hidden_dims, out_dim]
        layers = []
        # Pair each width with its successor, stopping before the output layer.
        for d_in, d_out in zip(dims[:-2], dims[1:-1]):
            layers += [nn.Linear(d_in, d_out), activation()]
        layers.append(nn.Linear(dims[-2], dims[-1]))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x`` (last dimension must be ``in_dim``)."""
        return self.net(x)
class EXP(nn.Module):
    """Forecasting model: wavelet denoising, time embeddings, two Sandwich blocks.

    Args:
        args: Mapping of hyper-parameters.
            Required keys: ``"horizon"``, ``"output_dim"``, ``"num_nodes"``.
            Optional keys (default): ``"in_len"`` (12), ``"hidden_dim"`` (64),
            ``"embed_dim"`` (16), ``"time_slots"`` (``24 * 60 // time_slot``),
            ``"time_slot"`` (5).
            ``hidden_dim`` must be divisible by 4 because the attention
            blocks use 4 heads.
    """

    def __init__(self, args):
        super().__init__()
        self.horizon = args["horizon"]
        self.output_dim = args["output_dim"]
        self.seq_len = args.get("in_len", 12)
        self.hidden_dim = args.get("hidden_dim", 64)
        self.num_nodes = args["num_nodes"]
        self.embed_dim = args.get("embed_dim", 16)

        # Discrete time embeddings: one table for the slot within the day,
        # one for the day of the week.
        self.time_slots = args.get("time_slots", 24 * 60 // args.get("time_slot", 5))
        self.time_embedding = nn.Embedding(self.time_slots, self.hidden_dim)
        self.day_embedding = nn.Embedding(7, self.hidden_dim)

        # Wavelet denoising layer applied to the raw flow history.
        self.wavelet = WaveletDenoise()

        # The input projection only consumes the denoised flow history.
        self.input_proj = MLP(
            in_dim=self.seq_len, hidden_dims=[self.hidden_dim], out_dim=self.hidden_dim
        )

        # Two Sandwich blocks.
        self.sandwich1 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        self.sandwich2 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)

        # Output projection to horizon * output_dim values per node.
        self.out_proj = MLP(
            in_dim=self.hidden_dim,
            hidden_dims=[2 * self.hidden_dim],
            out_dim=self.horizon * self.output_dim,
        )

    def forward(self, x):
        """Predict ``horizon`` future steps for every node.

        Args:
            x: Tensor of shape (B, T, N, D_total) with D_total >= 3, where
                x[..., 0] = flow,
                x[..., 1] = time_in_day (0 … 1),
                x[..., 2] = day_in_week (0–6).

        Returns:
            Tensor of shape (B, horizon, N, output_dim).

        Raises:
            ValueError: If ``T`` differs from the configured ``seq_len``.
        """
        x_flow = x[..., 0]  # (B, T, N)
        x_time = x[..., 1]  # (B, T, N)
        x_day = x[..., 2]  # (B, T, N)
        B, T, N = x_flow.shape
        # Explicit check instead of `assert`, which `python -O` strips; the
        # input projection is sized for exactly seq_len time steps.
        if T != self.seq_len:
            raise ValueError(
                f"expected {self.seq_len} input time steps, got {T}"
            )

        # 1) Wavelet denoising.
        x_dn = self.wavelet(x_flow)  # (B, T, N)

        # 2) Project each node's denoised flow history to hidden_dim.
        x_flat = x_dn.permute(0, 2, 1).reshape(B * N, T)
        h0 = self.input_proj(x_flat).view(B, N, self.hidden_dim)

        # 3) Look up discrete time indices at the last input time step.
        # NOTE(review): .long() truncates toward zero; if time_in_day is
        # encoded as k / time_slots, neighbouring slots can collapse onto the
        # same index -- confirm against the data loader.
        t_idx = (x_time[:, -1, :] * (self.time_slots - 1)).long()
        d_idx = x_day[:, -1, :].long()
        time_emb = self.time_embedding(t_idx)  # (B, N, hidden_dim)
        day_emb = self.day_embedding(d_idx)  # (B, N, hidden_dim)

        # 4) Inject the time information into the initial hidden state.
        h0 = h0 + time_emb + day_emb

        # 5) Sandwich blocks; only the first one has a residual connection.
        h1 = self.sandwich1(h0)
        h1 = h1 + h0
        h2 = self.sandwich2(h1)

        # 6) Output projection.
        out = self.out_proj(h2)  # (B, N, horizon * output_dim)
        out = out.view(B, N, self.horizon, self.output_dim)
        out = out.permute(0, 2, 1, 3)  # (B, horizon, N, output_dim)
        return out