TrafficWheel/model/EXP/trash/EXP8b.py

134 lines
4.1 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
"""
含残差版本
"""
class DynamicGraphConstructor(nn.Module):
    """Learn a dense adjacency matrix from two node-embedding tables.

    Args:
        node_num: Number of graph nodes (N).
        embed_dim: Size of each node embedding (D).
    """

    def __init__(self, node_num, embed_dim):
        super().__init__()
        # Two independent embedding tables, so the learned adjacency is not
        # forced to be symmetric.
        self.nodevec1 = nn.Parameter(torch.randn(node_num, embed_dim))
        self.nodevec2 = nn.Parameter(torch.randn(node_num, embed_dim))

    def forward(self):
        """Return the (N, N) adjacency matrix; every row sums to 1."""
        # (N, D) @ (D, N) -> (N, N) pairwise similarity scores.
        scores = self.nodevec1 @ self.nodevec2.T
        # Negative scores are zeroed before the row-wise normalisation.
        return F.softmax(F.relu(scores), dim=-1)
class GraphConvBlock(nn.Module):
    """Single graph-convolution layer with a residual shortcut.

    Args:
        input_dim: Feature size of the incoming node features.
        output_dim: Feature size produced by the layer.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.theta = nn.Linear(input_dim, output_dim)
        # An identity shortcut is only possible when the feature size is kept;
        # otherwise the shortcut is projected to the output size.
        self.residual = input_dim == output_dim
        if not self.residual:
            self.res_proj = nn.Linear(input_dim, output_dim)

    def forward(self, x, adj):
        """Aggregate neighbours, transform, add the shortcut, then apply ReLU.

        Args:
            x: Node features, (B, N, C).
            adj: Adjacency matrix, (N, N).
        """
        shortcut = x if self.residual else self.res_proj(x)
        aggregated = adj @ x  # (B, N, C): mix features across nodes
        return F.relu(self.theta(aggregated) + shortcut)
class MANBA_Block(nn.Module):
    """Multi-head self-attention followed by a feed-forward network.

    Each of the two sub-layers is wrapped in a residual connection followed
    by LayerNorm (post-norm arrangement).

    Args:
        input_dim: Feature size of the input and of the output; also the
            attention embedding size, so it must be divisible by
            ``num_heads``.
        hidden_dim: Width of the hidden layer of the feed-forward network.
        num_heads: Number of attention heads. Defaults to 4.
    """

    def __init__(self, input_dim, hidden_dim, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=input_dim, num_heads=num_heads, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        """Run the block on ``x`` of shape (B, T, C); the shape is preserved.

        Attention is computed along the second axis (``batch_first=True``).
        """
        # Self-attention sub-layer: residual add, then normalise.
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)
        # Feed-forward sub-layer: residual add, then normalise.
        x = self.norm2(x + self.ffn(x))
        return x
class EXP(nn.Module):
    """Forecasting model: input projection + learned graph conv + attention.

    Args:
        args: Mapping with the model configuration.
            Required keys: ``'horizon'``, ``'output_dim'``, ``'num_nodes'``.
            Optional keys: ``'in_len'`` (default 12), ``'hidden_dim'``
            (default 64), ``'time_slots'`` (default ``24 * 60 // time_slot``)
            and ``'time_slot'`` (default 5).
    """

    def __init__(self, args):
        super().__init__()
        self.horizon = args['horizon']
        self.output_dim = args['output_dim']
        self.seq_len = args.get('in_len', 12)
        self.hidden_dim = args.get('hidden_dim', 64)
        self.num_nodes = args['num_nodes']
        self.time_slots = args.get('time_slots', 24 * 60 // args.get('time_slot', 5))

        self.time_embedding = nn.Embedding(self.time_slots, self.hidden_dim)
        self.day_embedding = nn.Embedding(7, self.hidden_dim)

        # Learned (input-independent) adjacency matrix.
        self.graph = DynamicGraphConstructor(self.num_nodes, embed_dim=16)

        # Maps each node's whole input window (length seq_len) to hidden_dim.
        self.input_proj = nn.Linear(self.seq_len, self.hidden_dim)

        # Graph convolution.
        self.gc = GraphConvBlock(self.hidden_dim, self.hidden_dim)

        # Attention block.
        self.manba = MANBA_Block(self.hidden_dim, self.hidden_dim * 2)

        # Output projection.
        self.out_proj = nn.Linear(self.hidden_dim, self.horizon * self.output_dim)

    def forward(self, x):
        """Predict ``horizon`` future steps for every node.

        Args:
            x: Tensor of shape (B, T, N, D_total) with ``T == in_len``.
                Channel 0 is the main signal, channel 1 is read as the time
                feature and channel 2 as the day feature.

        Returns:
            Tensor of shape (B, horizon, N, output_dim).

        Raises:
            ValueError: If ``x`` is not 4-D or its sequence length differs
                from the configured ``in_len``.
        """
        # Explicit checks instead of `assert`, which is stripped under
        # `python -O` and would leave an opaque shape error further down.
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D input (B, T, N, D), got shape {tuple(x.shape)}"
            )

        x_time = x[..., 1]  # (B, T, N)
        x_day = x[..., 2]  # (B, T, N)
        x = x[..., 0]  # main channel only, (B, T, N)
        B, T, N = x.shape
        if T != self.seq_len:
            raise ValueError(
                f"input sequence length {T} does not match configured in_len {self.seq_len}"
            )

        # Input projection: (B, T, N) -> (B, N, T) -> (B*N, T) -> (B*N, H)
        x = x.permute(0, 2, 1).reshape(B * N, T)
        h = self.input_proj(x)  # (B*N, hidden_dim)
        h = h.view(B, N, self.hidden_dim)

        # Only the last input step supplies the time/day indices.
        # NOTE(review): scaling by (time_slots - 1) presumes the time feature
        # lies in [0, 1] and the day feature is an integer in 0..6 -- confirm
        # against the data pipeline.
        t_idx = (x_time[:, -1, :] * (self.time_slots - 1)).long()  # (B, N)
        d_idx = x_day[:, -1, :].long()  # (B, N)
        time_emb = self.time_embedding(t_idx)  # (B, N, hidden_dim)
        day_emb = self.day_embedding(d_idx)  # (B, N, hidden_dim)

        # Inject both embeddings into the initial hidden state.
        h = h + time_emb + day_emb

        # Learned adjacency.
        adj = self.graph()  # (N, N)

        # Spatial modelling: graph convolution.
        h = self.gc(h, adj)  # (B, N, hidden_dim)

        # Attention block. h is (B, N, hidden_dim), so the sequence axis the
        # block attends over is the node axis; the time axis was already
        # folded into the features by input_proj.
        h = self.manba(h)  # (B, N, hidden_dim)

        # Output projection.
        out = self.out_proj(h)  # (B, N, horizon * output_dim)
        out = out.view(B, N, self.horizon, self.output_dim).permute(0, 2, 1, 3)
        return out  # (B, horizon, N, output_dim)