# TrafficWheel/model/EXP/trash/EXP22.py

import torch
import torch.nn as nn
import torch.nn.functional as F
"""
添加空间嵌入
"""
class DynamicGraphConstructor(nn.Module):
    def __init__(self, node_num, embed_dim):
        super().__init__()
        # Adaptive adjacency parameters: two learnable node-embedding tables.
        self.nodevec1 = nn.Parameter(torch.randn(node_num, embed_dim), requires_grad=True)
        self.nodevec2 = nn.Parameter(torch.randn(node_num, embed_dim), requires_grad=True)

    def forward(self):
        # Build the dynamic adjacency matrix from the embedding product.
        adj = torch.matmul(self.nodevec1, self.nodevec2.T)  # (N, N)
        adj = F.relu(adj)
        adj = F.softmax(adj, dim=-1)  # row-normalize: each row sums to 1
        return adj
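
# A quick sanity check of the constructor (hypothetical sizes, not taken from
# the project's config): the returned adjacency is row-stochastic because
# softmax is applied over dim=-1.
#
#   gc = DynamicGraphConstructor(node_num=207, embed_dim=16)
#   adj = gc()          # shape (207, 207)
#   adj.sum(dim=-1)     # every row sums to 1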

class GraphConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Linear transform + optional residual projection when dims differ.
        self.theta = nn.Linear(input_dim, output_dim)
        self.residual = (input_dim == output_dim)
        if not self.residual:
            self.res_proj = nn.Linear(input_dim, output_dim)

    def forward(self, x, adj):
        # x: (B, N, F_in), adj: (N, N)
        res = x
        x = torch.matmul(adj, x)  # aggregate neighbor features
        x = self.theta(x)         # linear transform
        # Residual connection (projected if the dimensions differ).
        x = x + (res if self.residual else self.res_proj(res))
        return F.relu(x)

class MANBA_Block(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Spatial self-attention + feed-forward network, post-norm style.
        self.attn = nn.MultiheadAttention(embed_dim=input_dim, num_heads=4, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        # x: (B, N, C); attention mixes information across the node axis N.
        res = x
        x_attn, _ = self.attn(x, x, x)
        x = self.norm1(res + x_attn)
        res2 = x
        x_ffn = self.ffn(x)
        x = self.norm2(res2 + x_ffn)
        return x

class SandwichBlock(nn.Module):
    def __init__(self, num_nodes, embed_dim, hidden_dim):
        super().__init__()
        self.manba1 = MANBA_Block(hidden_dim, hidden_dim * 2)
        self.graph_constructor = DynamicGraphConstructor(num_nodes, embed_dim)
        self.gc = GraphConvBlock(hidden_dim, hidden_dim)
        self.manba2 = MANBA_Block(hidden_dim, hidden_dim * 2)

    def forward(self, h):
        # h: (B, N, hidden_dim)
        h1 = self.manba1(h)
        adj = self.graph_constructor()
        h2 = self.gc(h1, adj)
        h3 = self.manba2(h2)
        return h3
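
# Shape trace through one SandwichBlock, for reference (h: (B, N, hidden_dim)):
#   manba1            attention + FFN over the node axis      -> (B, N, hidden_dim)
#   graph_constructor learned adjacency                       -> (N, N)
#   gc                adj @ h1, then linear + residual + ReLU -> (B, N, hidden_dim)
#   manba2            second attention/FFN pass               -> (B, N, hidden_dim)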

class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dims, out_dim, activation=nn.ReLU):
        super().__init__()
        # Multi-layer perceptron: Linear + activation pairs, then a final Linear.
        dims = [in_dim] + hidden_dims + [out_dim]
        layers = []
        for i in range(len(dims) - 2):
            layers += [nn.Linear(dims[i], dims[i + 1]), activation()]
        layers += [nn.Linear(dims[-2], dims[-1])]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Linear mapping over the last dimension.
        return self.net(x)

class EXP(nn.Module):
    def __init__(self, args):
        super().__init__()
        # Training & output parameters.
        self.horizon = args['horizon']
        self.output_dim = args['output_dim']
        self.seq_len = args.get('in_len', 12)
        self.hidden_dim = args.get('hidden_dim', 64)
        self.num_nodes = args['num_nodes']
        self.embed_dim = args.get('embed_dim', 16)
        # ==== Temporal embeddings ====
        self.time_slots = args.get('time_slots', 24 * 60 // args.get('time_slot', 5))
        self.time_embedding = nn.Embedding(self.time_slots, self.hidden_dim)
        self.day_embedding = nn.Embedding(7, self.hidden_dim)
        # Legacy per-node embedding (currently unused; superseded by
        # spatial_embedding below).
        self.node_emb = nn.Parameter(torch.randn(self.num_nodes, self.embed_dim))
        # ==== Spatial embedding ====
        # One learnable vector per node.
        self.spatial_embedding = nn.Parameter(
            torch.randn(self.num_nodes, self.hidden_dim),
            requires_grad=True
        )
        # Input projection: an MLP over the flow sequence only.
        self.input_proj = MLP(
            in_dim=self.seq_len,
            hidden_dims=[self.hidden_dim],
            out_dim=self.hidden_dim
        )
        # Two SandwichBlocks.
        self.sandwich1 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        self.sandwich2 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        # Output projection.
        self.out_proj = MLP(
            in_dim=self.hidden_dim,
            hidden_dims=[2 * self.hidden_dim],
            out_dim=self.horizon * self.output_dim
        )

    def forward(self, x):
        """
        x: (B, T, N, D_total), where D_total >= 3 and
            x[..., 0] = traffic flow
            x[..., 1] = time of day (time_in_day, normalized to [0, 1])
            x[..., 2] = day of week (day_in_week, integer in 0..6)
        """
        # Split the three input series.
        x_flow = x[..., 0]  # (B, T, N)
        x_time = x[..., 1]  # (B, T, N)
        x_day = x[..., 2]   # (B, T, N)
        B, T, N = x_flow.shape
        assert T == self.seq_len, f"expected sequence length {self.seq_len}, got {T}"
        # 1) Project the flow history with an MLP -> initial node features h0.
        x_flat = x_flow.permute(0, 2, 1).reshape(B * N, T)        # (B*N, T)
        h0 = self.input_proj(x_flat).view(B, N, self.hidden_dim)  # (B, N, hidden_dim)
        # 2) Discrete temporal embeddings, taken from the last time step.
        t_idx = (x_time[:, -1, :] * (self.time_slots - 1)).long()  # (B, N)
        d_idx = x_day[:, -1, :].long()                             # (B, N)
        time_emb = self.time_embedding(t_idx)  # (B, N, hidden_dim)
        day_emb = self.day_embedding(d_idx)    # (B, N, hidden_dim)
        # 3) Spatial embedding, broadcast to the batch size.
        spatial_emb = self.spatial_embedding.unsqueeze(0).expand(B, N, self.hidden_dim)  # (B, N, hidden_dim)
        # 4) Add all three embeddings to h0.
        h0 = h0 + time_emb + day_emb + spatial_emb
        # 5) Two SandwichBlocks, the first with a residual connection.
        h1 = self.sandwich1(h0)
        h1 = h1 + h0
        h2 = self.sandwich2(h1)
        # 6) Output projection -> (B, horizon, N, output_dim).
        out = self.out_proj(h2)  # (B, N, horizon * output_dim)
        out = out.view(B, N, self.horizon, self.output_dim)
        out = out.permute(0, 2, 1, 3)  # (B, horizon, N, output_dim)
        return out
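
if __name__ == "__main__":
    # Minimal smoke test: a sketch only. The hyperparameters below are assumed
    # placeholder values chosen to exercise the forward pass; the project's
    # real config is defined elsewhere in TrafficWheel.
    args = {
        'num_nodes': 207,   # assumed, e.g. a METR-LA-sized sensor graph
        'horizon': 12,
        'output_dim': 1,
        'in_len': 12,
        'hidden_dim': 64,
        'embed_dim': 16,
        'time_slot': 5,     # minutes per slot -> 24 * 60 // 5 = 288 slots/day
    }
    model = EXP(args)
    B, T, N = 2, args['in_len'], args['num_nodes']
    x = torch.rand(B, T, N, 3)                          # [flow, time_in_day, day_in_week]
    x[..., 2] = torch.randint(0, 7, (B, T, N)).float()  # integer day index in 0..6
    y = model(x)
    assert y.shape == (B, args['horizon'], N, args['output_dim'])
    print("output shape:", tuple(y.shape))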