TrafficWheel/model/EXP/trash/EXP14.py

import torch
import torch.nn as nn
import torch.nn.functional as F
"""
含时间/空间额外特征的双层 时间->空间->时间 三明治结构模型
使用 x[...,0] 主通道x[...,1] time_in_dayx[...,2] day_in_week
通过独立投影融合三路信息
无改进
"""
class DynamicGraphConstructor(nn.Module):
    def __init__(self, node_num, embed_dim):
        super().__init__()
        self.nodevec1 = nn.Parameter(torch.randn(node_num, embed_dim), requires_grad=True)
        self.nodevec2 = nn.Parameter(torch.randn(node_num, embed_dim), requires_grad=True)

    def forward(self):
        # Build the dynamic adjacency matrix (N, N)
        adj = torch.matmul(self.nodevec1, self.nodevec2.T)
        adj = F.relu(adj)
        adj = F.softmax(adj, dim=-1)
        return adj
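
# Usage sketch (illustrative sizes, not from this repo's config). The module
# is parameter-only, so calling it with no inputs yields the learned adjacency:
#   gc = DynamicGraphConstructor(node_num=307, embed_dim=16)
#   adj = gc()  # (307, 307); softmax over dim=-1 makes each row sum to 1
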
class GraphConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.theta = nn.Linear(input_dim, output_dim)
        self.residual = (input_dim == output_dim)
        if not self.residual:
            self.res_proj = nn.Linear(input_dim, output_dim)

    def forward(self, x, adj):
        # x: (B, N, C)
        res = x
        x = torch.matmul(adj, x)
        x = self.theta(x)
        x = x + (res if self.residual else self.res_proj(res))
        return F.relu(x)
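
# One propagation step under assumed shapes (B=2, N=307, C=64), reusing the
# `adj` from the sketch above: adj @ x mixes features across nodes, theta maps
# channels, and the residual path is the identity since input_dim == output_dim.
#   block = GraphConvBlock(input_dim=64, output_dim=64)
#   h = block(torch.randn(2, 307, 64), adj)  # (2, 307, 64)
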
class MANBA_Block(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=input_dim, num_heads=4, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        # x: (B, N, C); the N axis is treated as the sequence length
        res = x
        x_attn, _ = self.attn(x, x, x)
        x = self.norm1(res + x_attn)
        res2 = x
        x_ffn = self.ffn(x)
        x = self.norm2(res2 + x_ffn)
        return x
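
# Caveat: with batch_first=True, nn.MultiheadAttention treats the second axis
# as the sequence, so on (B, N, C) input this block attends across the N (node)
# axis, as the comment above states. Sketch with assumed sizes (embed_dim=64 is
# divisible by the 4 heads):
#   blk = MANBA_Block(input_dim=64, hidden_dim=128)
#   y = blk(torch.randn(2, 307, 64))  # (2, 307, 64)
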
class SandwichBlock(nn.Module):
    """
    Temporal -> spatial -> temporal sandwich structure.
    Input/output: (B, N, hidden_dim)
    """
    def __init__(self, num_nodes, embed_dim, hidden_dim):
        super().__init__()
        self.manba1 = MANBA_Block(hidden_dim, hidden_dim * 2)
        self.graph_constructor = DynamicGraphConstructor(num_nodes, embed_dim)
        self.gc = GraphConvBlock(hidden_dim, hidden_dim)
        self.manba2 = MANBA_Block(hidden_dim, hidden_dim * 2)

    def forward(self, h):
        # h: (B, N, hidden_dim)
        # Step 1: temporal attention
        h1 = self.manba1(h)
        # Step 2: spatial graph convolution
        adj = self.graph_constructor()
        h2 = self.gc(h1, adj)
        # Step 3: temporal attention
        h3 = self.manba2(h2)
        return h3
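
# Shape-preserving by construction: both sub-blocks map (B, N, hidden_dim) to
# (B, N, hidden_dim), which is what lets EXP below stack two SandwichBlocks and
# add a residual between them. Sketch with assumed sizes:
#   sw = SandwichBlock(num_nodes=307, embed_dim=16, hidden_dim=64)
#   out = sw(torch.randn(2, 307, 64))  # (2, 307, 64)
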
class EXP(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.horizon = args['horizon']
        self.output_dim = args['output_dim']
        self.seq_len = args.get('in_len', 12)
        self.hidden_dim = args.get('hidden_dim', 64)
        self.num_nodes = args['num_nodes']
        self.embed_dim = args.get('embed_dim', 16)
        # Project each of the three input streams to the hidden dimension
        self.main_proj = nn.Linear(self.seq_len, self.hidden_dim)
        self.time_proj = nn.Linear(self.seq_len, self.hidden_dim)
        self.week_proj = nn.Linear(self.seq_len, self.hidden_dim)
        # Two stacked temporal -> spatial -> temporal sandwich blocks
        self.sandwich1 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        self.sandwich2 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        # Output projection
        self.out_proj = nn.Linear(self.hidden_dim, self.horizon * self.output_dim)

    def forward(self, x):
        # x: (B, T, N, D_total)
        x_main = x[..., 0]  # (B, T, N)
        x_time = x[..., 1]  # (B, T, N)
        x_week = x[..., 2]  # (B, T, N)
        B, T, N = x_main.shape
        assert T == self.seq_len
        # Project the three feature streams separately, then sum them
        x_main_flat = x_main.permute(0, 2, 1).reshape(B * N, T)
        h_main = self.main_proj(x_main_flat).view(B, N, self.hidden_dim)
        x_time_flat = x_time.permute(0, 2, 1).reshape(B * N, T)
        h_time = self.time_proj(x_time_flat).view(B, N, self.hidden_dim)
        x_week_flat = x_week.permute(0, 2, 1).reshape(B * N, T)
        h_week = self.week_proj(x_week_flat).view(B, N, self.hidden_dim)
        # Initial hidden representation fusing the three streams
        h0 = h_main + h_time + h_week
        # First sandwich block + residual
        h1 = self.sandwich1(h0)
        h1 = h1 + h0
        # Second sandwich block
        h2 = self.sandwich2(h1)
        # Output
        out = self.out_proj(h2)
        out = out.view(B, N, self.horizon, self.output_dim)
        out = out.permute(0, 2, 1, 3)  # (B, horizon, N, D_out)
        return out

# Smoke test (hypothetical hyperparameters; adjust to the real dataset config)
if __name__ == "__main__":
    args = {'horizon': 12, 'output_dim': 1, 'in_len': 12,
            'hidden_dim': 64, 'num_nodes': 307, 'embed_dim': 16}
    model = EXP(args)
    x = torch.randn(16, 12, 307, 3)  # (B, T, N, 3): main / time_in_day / day_in_week
    print(model(x).shape)  # expected: torch.Size([16, 12, 307, 1])