TrafficWheel/model/EXP/trash/EXP23.py


import torch
import torch.nn as nn
import torch.nn.functional as F
"""
添加时间嵌入 + 基于可学习邻接矩阵的图构造
"""


class DynamicGraphConstructor(nn.Module):
    def __init__(self, node_num):
        super().__init__()
        # Represent the adjacency directly with a learnable N x N parameter matrix.
        self.adj_param = nn.Parameter(
            torch.randn(node_num, node_num), requires_grad=True
        )

    def forward(self):
        # Nonlinear truncation to drop negative edges.
        adj = F.relu(self.adj_param)
        # Row-wise normalization.
        adj = F.softmax(adj, dim=-1)
        return adj


class GraphConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.theta = nn.Linear(input_dim, output_dim)
        self.residual = input_dim == output_dim
        if not self.residual:
            self.res_proj = nn.Linear(input_dim, output_dim)

    def forward(self, x, adj):
        # x: (B, N, C)
        res = x
        # Aggregate neighbor features: adjacency times node features.
        x = torch.matmul(adj, x)
        x = self.theta(x)
        x = x + (res if self.residual else self.res_proj(res))
        return F.relu(x)
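

# MANBA_Block is a standard post-norm Transformer encoder layer (multi-head
# self-attention + feed-forward, each with a residual connection and LayerNorm),
# applied here across the node dimension N.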
class MANBA_Block(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=input_dim, num_heads=4, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        # x: (B, N, C)
        res = x
        x_attn, _ = self.attn(x, x, x)
        x = self.norm1(res + x_attn)
        res2 = x
        x_ffn = self.ffn(x)
        x = self.norm2(res2 + x_ffn)
        return x
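

# SandwichBlock: node attention -> graph convolution over a learnable adjacency
# -> node attention, i.e. the graph layer is "sandwiched" between two
# attention blocks.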
class SandwichBlock(nn.Module):
    def __init__(self, num_nodes, hidden_dim):
        super().__init__()
        self.manba1 = MANBA_Block(hidden_dim, hidden_dim * 2)
        self.graph_constructor = DynamicGraphConstructor(num_nodes)
        self.gc = GraphConvBlock(hidden_dim, hidden_dim)
        self.manba2 = MANBA_Block(hidden_dim, hidden_dim * 2)

    def forward(self, h):
        # h: (B, N, C)
        h1 = self.manba1(h)
        adj = self.graph_constructor()  # (N, N)
        h2 = self.gc(h1, adj)
        h3 = self.manba2(h2)
        return h3


class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dims, out_dim, activation=nn.ReLU):
        super().__init__()
        dims = [in_dim] + hidden_dims + [out_dim]
        layers = []
        for i in range(len(dims) - 2):
            layers += [nn.Linear(dims[i], dims[i + 1]), activation()]
        layers += [nn.Linear(dims[-2], dims[-1])]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class EXP(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.horizon = args["horizon"]
        self.output_dim = args["output_dim"]
        self.seq_len = args.get("in_len", 12)
        self.hidden_dim = args.get("hidden_dim", 64)
        self.num_nodes = args["num_nodes"]
        # ==== Discrete time embeddings ====
        self.time_slots = args.get("time_slots", 24 * 60 // args.get("time_slot", 5))
        self.time_embedding = nn.Embedding(self.time_slots, self.hidden_dim)
        self.day_embedding = nn.Embedding(7, self.hidden_dim)
        # Projection of the flow history
        self.input_proj = MLP(
            in_dim=self.seq_len, hidden_dims=[self.hidden_dim], out_dim=self.hidden_dim
        )
        # Two SandwichBlocks
        self.sandwich1 = SandwichBlock(self.num_nodes, self.hidden_dim)
        self.sandwich2 = SandwichBlock(self.num_nodes, self.hidden_dim)
        # Output projection
        self.out_proj = MLP(
            in_dim=self.hidden_dim,
            hidden_dims=[2 * self.hidden_dim],
            out_dim=self.horizon * self.output_dim,
        )

    def forward(self, x):
        """
        x: (B, T, N, D_total)
        D_total >= 3:
            x[..., 0] = flow,
            x[..., 1] = time_in_day (0…1),
            x[..., 2] = day_in_week (0…6)
        """
        x_flow = x[..., 0]  # (B, T, N)
        x_time = x[..., 1]  # (B, T, N)
        x_day = x[..., 2]  # (B, T, N)
        B, T, N = x_flow.shape
        assert T == self.seq_len
        # 1) Project the flow history
        x_flat = x_flow.permute(0, 2, 1).reshape(B * N, T)
        h0 = self.input_proj(x_flat).view(B, N, self.hidden_dim)
        # 2) Discrete time indices, taken at the last input step
        t_idx = (x_time[:, -1, :] * (self.time_slots - 1)).long()  # (B, N)
        d_idx = x_day[:, -1, :].long()  # (B, N)
        time_emb = self.time_embedding(t_idx)
        day_emb = self.day_embedding(d_idx)
        # 3) Inject the time embeddings
        h0 = h0 + time_emb + day_emb
        # 4) Sandwich blocks + residual
        h1 = self.sandwich1(h0)
        h1 = h1 + h0
        h2 = self.sandwich2(h1)
        # 5) Output projection
        out = self.out_proj(h2)  # (B, N, horizon*output_dim)
        out = out.view(B, N, self.horizon, self.output_dim)
        out = out.permute(0, 2, 1, 3)  # (B, horizon, N, output_dim)
        return out
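

if __name__ == "__main__":
    # Minimal usage sketch: the args keys are the ones EXP.__init__ reads;
    # the concrete values here are arbitrary example settings, not a config
    # shipped with TrafficWheel.
    args = {
        "horizon": 12,
        "output_dim": 1,
        "in_len": 12,
        "hidden_dim": 64,
        "num_nodes": 207,
        "time_slot": 5,
    }
    model = EXP(args)
    # Dummy batch: (B, T, N, D_total) with flow, time_in_day in [0, 1],
    # day_in_week in {0, ..., 6}.
    B, T, N = 2, args["in_len"], args["num_nodes"]
    x = torch.zeros(B, T, N, 3)
    x[..., 0] = torch.randn(B, T, N)                    # flow
    x[..., 1] = torch.rand(B, T, N)                     # time_in_day
    x[..., 2] = torch.randint(0, 7, (B, T, N)).float()  # day_in_week
    y = model(x)
    print(y.shape)  # expected: torch.Size([2, 12, 207, 1])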