TrafficWheel/model/EXP/trash/EXP11.py

124 lines
3.8 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
"""
含残差版本 + 时间-空间-时间(三明治结构)
"""
class DynamicGraphConstructor(nn.Module):
    """Learnable dense adjacency matrix built from two node-embedding tables.

    Two ``(node_num, embed_dim)`` parameter tables are multiplied together to
    score every ordered node pair; the scores go through ReLU and are then
    softmax-normalised along the last axis, so every row of the result sums
    to 1.

    Args:
        node_num: Number of graph nodes (rows of each embedding table).
        embed_dim: Width of each node embedding.
    """

    def __init__(self, node_num, embed_dim):
        super().__init__()
        table_shape = (node_num, embed_dim)
        # nodevec1 is drawn before nodevec2 so that seeded initialisation
        # consumes the random stream in the same order as before.
        self.nodevec1 = nn.Parameter(torch.randn(*table_shape))
        self.nodevec2 = nn.Parameter(torch.randn(*table_shape))

    def forward(self):
        """Return the ``(node_num, node_num)`` row-normalised adjacency."""
        # (N, D) @ (D, N) -> (N, N) pairwise scores
        pair_scores = self.nodevec1 @ self.nodevec2.T
        # ReLU clamps negative scores to zero before the row-wise softmax.
        return F.softmax(F.relu(pair_scores), dim=-1)
class GraphConvBlock(nn.Module):
    """One graph-convolution step with a residual shortcut.

    Node features are mixed with the adjacency matrix, passed through a
    linear layer, added to a shortcut of the input and finally activated
    with ReLU. The shortcut is the raw input when ``input_dim`` equals
    ``output_dim``; otherwise the input is first projected by a separate
    linear layer (``res_proj``) so the widths match.

    Args:
        input_dim: Feature width of the incoming node features.
        output_dim: Feature width produced by the block.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.theta = nn.Linear(input_dim, output_dim)
        # True  -> identity shortcut; False -> learned projection shortcut.
        self.residual = input_dim == output_dim
        if not self.residual:
            self.res_proj = nn.Linear(input_dim, output_dim)

    def forward(self, x, adj):
        """Apply the block.

        Args:
            x: Node features, ``(B, N, C)`` per the original shape comment.
            adj: Adjacency matrix, ``(N, N)`` per the original shape comment.

        Returns:
            ReLU-activated features whose last axis has width ``output_dim``.
        """
        # Neighbourhood aggregation: (N, N) @ (B, N, C) -> (B, N, C)
        aggregated = torch.matmul(adj, x)
        transformed = self.theta(aggregated)
        # Residual connection.
        shortcut = x if self.residual else self.res_proj(x)
        return F.relu(transformed + shortcut)
class MANBA_Block(nn.Module):
    """Post-norm self-attention block followed by a feed-forward network.

    The block applies multi-head self-attention over axis 1 of a
    batch-first input, adds the result to the input and layer-normalises it,
    then does the same with a two-layer ReLU MLP. The output has the same
    shape as the input.

    Args:
        input_dim: Width of the last (feature) axis; used as the attention
            ``embed_dim`` and as the LayerNorm size.
        hidden_dim: Width of the hidden layer inside the feed-forward MLP.
        num_heads: Number of attention heads. Defaults to 4, the value that
            used to be hard-coded, so existing callers are unaffected.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``input_dim`` evenly.
    """

    def __init__(self, input_dim, hidden_dim, num_heads=4):
        super().__init__()
        # Fail early with a readable message instead of relying on the
        # internal check inside nn.MultiheadAttention.
        if num_heads <= 0 or input_dim % num_heads != 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.attn = nn.MultiheadAttention(
            embed_dim=input_dim, num_heads=num_heads, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        """Run attention + MLP with residual connections.

        Args:
            x: Batch-first tensor; the original comment describes it as
                ``(B, T, C)``, or ``(B, N, C)`` when the node axis is treated
                as the sequence axis.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention sub-layer (query = key = value = x), post-norm.
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)
        # Feed-forward sub-layer, post-norm.
        x = self.norm2(x + self.ffn(x))
        return x
class EXP(nn.Module):
    """Forecasting model: attention -> graph convolution -> attention.

    The input window of each node is projected to ``hidden_dim`` features,
    passed through a ``MANBA_Block``, a ``GraphConvBlock`` driven by a learned
    adjacency, and the *same* ``MANBA_Block`` again (its weights are shared
    between the two passes), then projected to the forecast horizon.

    Args:
        args: Mapping supporting ``[]`` and ``.get``. Required keys:
            ``'horizon'``, ``'output_dim'``, ``'num_nodes'``. Optional keys:
            ``'in_len'`` (input window length, default 12) and
            ``'hidden_dim'`` (default 64).
    """

    def __init__(self, args):
        super().__init__()
        self.horizon = args['horizon']
        self.output_dim = args['output_dim']
        self.seq_len = args.get('in_len', 12)
        self.hidden_dim = args.get('hidden_dim', 64)
        self.num_nodes = args['num_nodes']

        # Learned (dynamic) graph structure.
        self.graph = DynamicGraphConstructor(self.num_nodes, embed_dim=16)

        # Input projection: maps each node's length-seq_len window to hidden_dim.
        self.input_proj = nn.Linear(self.seq_len, self.hidden_dim)

        # Graph convolution.
        self.gc = GraphConvBlock(self.hidden_dim, self.hidden_dim)

        # Attention block (called "temporal modelling" in the original notes).
        self.manba = MANBA_Block(self.hidden_dim, self.hidden_dim * 2)

        # Output projection.
        self.out_proj = nn.Linear(self.hidden_dim, self.horizon * self.output_dim)

    def forward(self, x):
        """Predict the next ``horizon`` steps for every node.

        Args:
            x: Input tensor, ``(B, T, N, D_total)`` per the original shape
                comment. Only index 0 of the last axis is used.

        Returns:
            Tensor of shape ``(B, horizon, N, output_dim)``.

        Raises:
            ValueError: If the time axis length ``T`` differs from the
                configured ``seq_len``.
        """
        x = x[..., 0]  # keep only the primary channel -> (B, T, N)
        B, T, N = x.shape
        # Explicit check instead of `assert`, which is removed under
        # `python -O` and would let a mismatched window reach input_proj.
        if T != self.seq_len:
            raise ValueError(
                f"expected input sequence length {self.seq_len}, got {T}"
            )

        # Input projection:
        # (B, T, N) -> (B, N, T) -> (B*N, T) -> (B*N, hidden_dim)
        x_flat = x.permute(0, 2, 1).reshape(B * N, T)
        h = self.input_proj(x_flat)  # (B*N, hidden_dim)
        h = h.view(B, N, self.hidden_dim)  # (B, N, hidden_dim)

        # === First attention pass ===
        # The node axis N is treated as the sequence axis for attention.
        h_time1 = self.manba(h)  # (B, N, hidden_dim)

        # Build the learned adjacency.
        adj = self.graph()  # (N, N)

        # === Spatial modelling ===
        h_space = self.gc(h_time1, adj)  # (B, N, hidden_dim)

        # === Second attention pass (same module, shared weights) ===
        h_time2 = self.manba(h_space)  # (B, N, hidden_dim)

        # Output projection.
        out = self.out_proj(h_time2)  # (B, N, horizon * output_dim)
        out = out.view(B, N, self.horizon, self.output_dim).permute(0, 2, 1, 3)
        return out  # (B, horizon, N, output_dim)