# TrafficWheel/model/EXP/trash/EXP20.py

import torch
import torch.nn as nn
import torch.nn.functional as F
"""
使用多层感知机替换输入输出的 proj 层并将图卷积替换为图注意力网络GAT
"""
class DynamicGraphConstructor(nn.Module):
    def __init__(self, node_num, embed_dim):
        super().__init__()
        self.nodevec1 = nn.Parameter(torch.randn(node_num, embed_dim), requires_grad=True)
        self.nodevec2 = nn.Parameter(torch.randn(node_num, embed_dim), requires_grad=True)

    def forward(self):
        # Build a learnable adjacency matrix from the two node embeddings
        adj = torch.matmul(self.nodevec1, self.nodevec2.T)  # (N, N)
        adj = F.relu(adj)
        adj = F.softmax(adj, dim=-1)
        return adj
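
# Illustrative sanity check, not part of the original model: the sizes below
# are arbitrary assumptions. It verifies that the learned adjacency is (N, N)
# and that each row sums to 1 after the softmax.
def _demo_dynamic_graph_constructor():
    constructor = DynamicGraphConstructor(node_num=5, embed_dim=8)
    adj = constructor()
    assert adj.shape == (5, 5)
    assert torch.allclose(adj.sum(dim=-1), torch.ones(5), atol=1e-5)
    return adj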
class GATConvBlock(nn.Module):
    """
    Minimal GAT implementation:
    - apply a linear transform to every node feature
    - compute an attention score for every pair of nodes
    - mask out non-edges (adj == 0), apply softmax, then take the weighted sum
    - add a residual connection and a non-linearity
    """
    def __init__(self, input_dim, output_dim, alpha=0.2):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim, bias=False)
        self.attn_fc = nn.Linear(2 * output_dim, 1, bias=False)
        self.leakyrelu = nn.LeakyReLU(alpha)
        self.residual = (input_dim == output_dim)
        if not self.residual:
            self.res_fc = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, x, adj):
        """
        x: (B, N, F_in)
        adj: (N, N), dynamically learned adjacency matrix
        returns h_prime: (B, N, F_out)
        """
        B, N, _ = x.shape
        h = self.fc(x)  # (B, N, F_out)
        # Attention score for every pair of nodes
        h_i = h.unsqueeze(2).expand(-1, -1, N, -1)  # (B, N, N, F_out)
        h_j = h.unsqueeze(1).expand(-1, N, -1, -1)  # (B, N, N, F_out)
        e = self.attn_fc(torch.cat([h_i, h_j], dim=-1)).squeeze(-1)  # (B, N, N)
        e = self.leakyrelu(e)
        # Mask: keep attention only where adj > 0, set the rest to -inf
        mask = adj.unsqueeze(0).expand(B, -1, -1) > 0
        e = e.masked_fill(~mask, float('-inf'))
        # Normalize attention weights
        alpha = F.softmax(e, dim=-1)  # (B, N, N)
        # Aggregate neighbor features
        h_prime = torch.matmul(alpha, h)  # (B, N, F_out)
        # Residual connection
        if self.residual:
            h_prime = h_prime + x
        else:
            h_prime = h_prime + self.res_fc(x)
        return F.elu(h_prime)
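
# Illustrative sanity check with assumed shapes (not from the original code):
# one GAT block over random features and a dense positive adjacency, checking
# that the output keeps the (B, N, F_out) layout. Because F_in != F_out here,
# the res_fc branch of the residual is exercised.
def _demo_gat_block():
    B, N, F_in, F_out = 2, 5, 16, 32
    block = GATConvBlock(F_in, F_out)
    x = torch.randn(B, N, F_in)
    adj = torch.rand(N, N) + 0.01  # strictly positive, so no row is fully masked
    out = block(x, adj)
    assert out.shape == (B, N, F_out)
    return out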
class MANBA_Block(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=input_dim, num_heads=4, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        # x: (B, N, input_dim) -- treat the node dimension as a sequence
        res = x
        x_attn, _ = self.attn(x, x, x)
        x = self.norm1(res + x_attn)
        res2 = x
        x_ffn = self.ffn(x)
        x = self.norm2(res2 + x_ffn)
        return x
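
# Illustrative sanity check with assumed sizes: the block is shape-preserving,
# and input_dim must be divisible by the 4 attention heads hard-coded above.
def _demo_manba_block():
    block = MANBA_Block(input_dim=64, hidden_dim=128)
    x = torch.randn(2, 10, 64)
    out = block(x)
    assert out.shape == x.shape
    return out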
class SandwichBlock(nn.Module):
    def __init__(self, num_nodes, embed_dim, hidden_dim):
        super().__init__()
        self.manba1 = MANBA_Block(hidden_dim, hidden_dim * 2)
        self.graph_constructor = DynamicGraphConstructor(num_nodes, embed_dim)
        self.gat = GATConvBlock(hidden_dim, hidden_dim)
        self.manba2 = MANBA_Block(hidden_dim, hidden_dim * 2)

    def forward(self, h):
        # h: (B, N, hidden_dim)
        h1 = self.manba1(h)              # self-attention + FFN
        adj = self.graph_constructor()   # dynamic adjacency (N, N)
        h2 = self.gat(h1, adj)           # GAT aggregation
        h3 = self.manba2(h2)             # second self-attention + FFN
        return h3
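
# Illustrative sanity check with assumed sizes: a SandwichBlock keeps the
# (B, N, hidden_dim) layout, which is what lets EXP stack two of them and add
# a residual between them. num_nodes must match the N dimension of the input.
def _demo_sandwich_block():
    num_nodes, hidden_dim = 5, 64
    block = SandwichBlock(num_nodes=num_nodes, embed_dim=16, hidden_dim=hidden_dim)
    h = torch.randn(2, num_nodes, hidden_dim)
    out = block(h)
    assert out.shape == h.shape
    return out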
class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dims, out_dim, activation=nn.ReLU):
        super().__init__()
        dims = [in_dim] + hidden_dims + [out_dim]
        layers = []
        for i in range(len(dims) - 2):
            layers += [nn.Linear(dims[i], dims[i + 1]), activation()]
        layers += [nn.Linear(dims[-2], dims[-1])]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Works for inputs of any leading shape; Linear acts on the last dimension
        return self.net(x)
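
# Illustrative sanity check with assumed sizes: the MLP acts on the last
# dimension only, so any leading shape is preserved.
def _demo_mlp():
    mlp = MLP(in_dim=12, hidden_dims=[64], out_dim=64)
    x = torch.randn(2, 5, 12)
    out = mlp(x)
    assert out.shape == (2, 5, 64)
    return out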
class EXP(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.horizon = args['horizon']
        self.output_dim = args['output_dim']
        self.seq_len = args.get('in_len', 12)
        self.hidden_dim = args.get('hidden_dim', 64)
        self.num_nodes = args['num_nodes']
        self.embed_dim = args.get('embed_dim', 16)
        # Replace the original input projection with an MLP
        self.input_proj = MLP(self.seq_len, [self.hidden_dim], self.hidden_dim)
        # Two SandwichBlock layers
        self.sandwich1 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        self.sandwich2 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        # Replace the original output projection with an MLP
        self.out_proj = MLP(self.hidden_dim, [2 * self.hidden_dim], self.horizon * self.output_dim)

    def forward(self, x):
        """
        x: (B, T, N, D_total)
        Assumes D_total >= 1 and that only feature 0 is used for prediction.
        Returns:
            out: (B, horizon, N, output_dim)
        """
        x_main = x[..., 0]  # (B, T, N)
        B, T, N = x_main.shape
        assert T == self.seq_len, f"Expected seq_len={self.seq_len}, got {T}"
        # (B, T, N) -> (B, N, T) -> (B*N, T) -> MLP -> (B, N, hidden_dim)
        x_flat = x_main.permute(0, 2, 1).reshape(B * N, T)
        h0 = self.input_proj(x_flat).view(B, N, self.hidden_dim)
        # Two Sandwich blocks with a residual connection
        h1 = self.sandwich1(h0)
        h1 = h1 + h0
        h2 = self.sandwich2(h1)
        # Output projection
        out = self.out_proj(h2)  # (B, N, horizon * output_dim)
        out = out.view(B, N, self.horizon, self.output_dim)
        out = out.permute(0, 2, 1, 3)  # (B, horizon, N, output_dim)
        return out
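
# Illustrative end-to-end usage. The args dict is an assumption: only the keys
# read in EXP.__init__ are known from this file, and the values are arbitrary
# (hidden_dim must stay divisible by the 4 attention heads in MANBA_Block).
def _demo_exp_forward():
    args = {
        'horizon': 12,
        'output_dim': 1,
        'in_len': 12,
        'hidden_dim': 64,
        'num_nodes': 5,
        'embed_dim': 16,
    }
    model = EXP(args)
    x = torch.randn(4, args['in_len'], args['num_nodes'], 3)  # (B, T, N, D_total)
    out = model(x)
    assert out.shape == (4, args['horizon'], args['num_nodes'], args['output_dim'])
    return out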