TrafficWheel/model/EXP/trash/EXP26.py


import torch
import torch.nn as nn
import torch.nn.functional as F

"""
Adds time embeddings + introduces a Graph Attention Network (GAT).
"""


class DynamicGraphConstructor(nn.Module):
    """Learns a dense adjacency matrix from two node-embedding tables."""
    def __init__(self, node_num, embed_dim):
        super().__init__()
        self.nodevec1 = nn.Parameter(torch.randn(node_num, embed_dim), requires_grad=True)
        self.nodevec2 = nn.Parameter(torch.randn(node_num, embed_dim), requires_grad=True)

    def forward(self):
        adj = torch.matmul(self.nodevec1, self.nodevec2.T)
        adj = F.relu(adj)
        adj = F.softmax(adj, dim=-1)  # row-normalize the learned adjacency
        return adj
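
# Usage sketch (node_num and embed_dim are illustrative values, not the project's config):
#   constructor = DynamicGraphConstructor(node_num=5, embed_dim=8)
#   adj = constructor()   # (5, 5); each row is softmax-normalized and sums to 1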

# The original GCN block, kept here for reference.
class GraphConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.theta = nn.Linear(input_dim, output_dim)
        self.residual = (input_dim == output_dim)
        if not self.residual:
            self.res_proj = nn.Linear(input_dim, output_dim)

    def forward(self, x, adj):
        res = x
        x = torch.matmul(adj, x)  # neighborhood aggregation
        x = self.theta(x)
        x = x + (res if self.residual else self.res_proj(res))
        return F.relu(x)

# ★★ GAT part: adapted from LeronQ/GCN_predict-Pytorch ★★
class GraphAttentionLayer(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.W = nn.Linear(in_c, out_c, bias=False)
        self.b = nn.Parameter(torch.Tensor(out_c))
        nn.init.xavier_uniform_(self.W.weight)
        nn.init.zeros_(self.b)

    def forward(self, h, adj):
        # h: [B, N, C_in], adj: [N, N]
        Wh = self.W(h)  # [B, N, C_out]
        # Attention scores, masked where the adjacency is zero
        score = torch.bmm(Wh, Wh.transpose(1, 2)) * adj.unsqueeze(0)  # [B, N, N]
        score = score.masked_fill(score == 0, -1e16)
        alpha = F.softmax(score, dim=-1)  # [B, N, N]
        # Weighted aggregation plus bias
        out = torch.bmm(alpha, Wh) + self.b  # [B, N, C_out]
        return F.relu(out)

class GraphAttentionBlock(nn.Module):
    def __init__(self, input_dim, output_dim, n_heads=4):
        super().__init__()
        # Multi-head attention
        self.heads = nn.ModuleList(
            [GraphAttentionLayer(input_dim, output_dim) for _ in range(n_heads)]
        )
        # A second attention layer projects the concatenated heads back to output_dim
        self.out_att = GraphAttentionLayer(output_dim * n_heads, output_dim)
        self.act = nn.ReLU()

    def forward(self, x, adj):
        # x: [B, N, C], adj: [N, N]
        # Run the heads in parallel, then concatenate
        h_cat = torch.cat([head(x, adj) for head in self.heads], dim=-1)  # [B, N, output_dim * n_heads]
        h_out = self.out_att(h_cat, adj)  # [B, N, output_dim]
        return self.act(h_out)
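
# Shape sketch for the GAT block above (sizes are illustrative assumptions):
#   block = GraphAttentionBlock(input_dim=64, output_dim=64, n_heads=4)
#   x = torch.rand(2, 5, 64)                        # [B, N, C]
#   adj = torch.softmax(torch.rand(5, 5), dim=-1)   # dense adjacency, e.g. from DynamicGraphConstructor
#   out = block(x, adj)                             # [2, 5, 64]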

class MANBA_Block(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=input_dim, num_heads=4, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        res = x
        x_attn, _ = self.attn(x, x, x)
        x = self.norm1(res + x_attn)
        res2 = x
        x_ffn = self.ffn(x)
        x = self.norm2(res2 + x_ffn)
        return x
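
# Usage sketch: MANBA_Block is a post-norm Transformer encoder step (self-attention
# followed by an FFN, each with a residual connection); here it attends across the
# node dimension. Sizes below are illustrative assumptions.
#   block = MANBA_Block(input_dim=64, hidden_dim=128)
#   h = block(torch.rand(2, 5, 64))   # shape preserved: [2, 5, 64]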

class SandwichBlock(nn.Module):
    def __init__(self, num_nodes, embed_dim, hidden_dim):
        super().__init__()
        self.manba1 = MANBA_Block(hidden_dim, hidden_dim * 2)
        self.graph_constructor = DynamicGraphConstructor(num_nodes, embed_dim)
        # ★★ Replaced the GCN block with the GAT block ★★
        self.gc = GraphAttentionBlock(hidden_dim, hidden_dim, n_heads=4)
        self.manba2 = MANBA_Block(hidden_dim, hidden_dim * 2)

    def forward(self, h):
        h1 = self.manba1(h)
        adj = self.graph_constructor()
        h2 = self.gc(h1, adj)
        h3 = self.manba2(h2)
        return h3
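
# Data flow through one SandwichBlock (attention -> learned graph + GAT -> attention);
# every stage keeps the [B, N, hidden_dim] shape. Sizes are illustrative assumptions.
#   block = SandwichBlock(num_nodes=5, embed_dim=8, hidden_dim=64)
#   h = block(torch.rand(2, 5, 64))   # [2, 5, 64]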

class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dims, out_dim, activation=nn.ReLU):
        super().__init__()
        dims = [in_dim] + hidden_dims + [out_dim]
        layers = []
        for i in range(len(dims) - 2):
            layers += [nn.Linear(dims[i], dims[i + 1]), activation()]
        layers += [nn.Linear(dims[-2], dims[-1])]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class EXP(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.horizon = args['horizon']
        self.output_dim = args['output_dim']
        self.seq_len = args.get('in_len', 12)
        self.hidden_dim = args.get('hidden_dim', 64)
        self.num_nodes = args['num_nodes']
        self.embed_dim = args.get('embed_dim', 16)
        # ==== New: discrete time embeddings ====
        self.time_slots = args.get('time_slots', 24 * 60 // args.get('time_slot', 5))
        self.time_embedding = nn.Embedding(self.time_slots, self.hidden_dim)
        self.day_embedding = nn.Embedding(7, self.hidden_dim)
        # Input projection (flow channel only)
        self.input_proj = MLP(
            in_dim=self.seq_len,
            hidden_dims=[self.hidden_dim],
            out_dim=self.hidden_dim
        )
        # Two SandwichBlocks (GCN already replaced with GAT)
        self.sandwich1 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        self.sandwich2 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        # Output projection
        self.out_proj = MLP(
            in_dim=self.hidden_dim,
            hidden_dims=[2 * self.hidden_dim],
            out_dim=self.horizon * self.output_dim
        )

    def forward(self, x):
        """
        x: (B, T, N, D_total)
        D_total >= 3, x[..., 0] = flow, x[..., 1] = time_in_day, x[..., 2] = day_in_week
        """
        x_flow = x[..., 0]  # (B, T, N)
        x_time = x[..., 1]  # (B, T, N)
        x_day = x[..., 2]   # (B, T, N)
        B, T, N = x_flow.shape
        assert T == self.seq_len
        # 1) Project the flow history
        x_flat = x_flow.permute(0, 2, 1).reshape(B * N, T)
        h0 = self.input_proj(x_flat).view(B, N, self.hidden_dim)
        # 2) Take the time indices of the last step and embed them
        #    (time_in_day is assumed to be normalized to [0, 1])
        t_idx = (x_time[:, -1, :] * (self.time_slots - 1)).long()
        d_idx = x_day[:, -1, :].long()
        time_emb = self.time_embedding(t_idx)
        day_emb = self.day_embedding(d_idx)
        # 3) Inject the temporal information
        h0 = h0 + time_emb + day_emb
        # 4) Sandwich blocks + residual
        h1 = self.sandwich1(h0)
        h1 = h1 + h0
        h2 = self.sandwich2(h1)
        # 5) Output
        out = self.out_proj(h2)
        out = out.view(B, N, self.horizon, self.output_dim).permute(0, 2, 1, 3)
        return out
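

# Minimal smoke test: a sketch for a quick shape check, not part of the project's
# training pipeline. All values in `args` below (including num_nodes=207) are
# illustrative assumptions, not the real experiment config.
if __name__ == "__main__":
    args = {
        'horizon': 12,
        'output_dim': 1,
        'in_len': 12,
        'hidden_dim': 64,
        'num_nodes': 207,
        'embed_dim': 16,
        'time_slot': 5,  # 5-minute slots -> 288 time-of-day embeddings
    }
    model = EXP(args)
    # Batch of 2, 12 input steps, 207 nodes, 3 channels: flow, time_in_day, day_in_week
    x = torch.rand(2, 12, 207, 3)                          # flow and time_in_day in [0, 1)
    x[..., 2] = torch.randint(0, 7, (2, 12, 207)).float()  # day-of-week index in [0, 6]
    y = model(x)
    print(y.shape)  # expected: torch.Size([2, 12, 207, 1])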