TrafficWheel/model/EXP/EXP32.py

183 lines
6.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# ------------------------- CycleNet Component -------------------------
class RecurrentCycle(nn.Module):
    """Learnable cycle table used to remove/add a periodic component (CycleNet).

    Holds one learnable value per (cycle phase, channel) and, for each sample,
    returns the window of the cycle that starts at that sample's phase.
    """

    def __init__(self, cycle_len, channel_size):
        super().__init__()
        self.cycle_len = cycle_len
        self.channel_size = channel_size
        # Cycle table of shape (cycle_len, channel_size), initialised to zeros.
        self.data = nn.Parameter(torch.zeros(cycle_len, channel_size))

    def forward(self, index, length):
        """Gather ``length`` consecutive cycle steps for every sample.

        Args:
            index: (B,) integer tensor holding each sample's starting phase.
            length: number of consecutive steps to gather (e.g. seq_len or
                pred_len).

        Returns:
            (B, length, channel_size) tensor whose step ``j`` for sample ``b``
            is ``data[(index[b] + j) % cycle_len]``, i.e. it wraps around the
            cycle.
        """
        offsets = torch.arange(length, device=index.device).unsqueeze(0)  # (1, length)
        idx = (index.unsqueeze(1) + offsets) % self.cycle_len  # (B, length)
        return self.data[idx]
# ------------------------- Core Blocks -------------------------
class DynamicGraphConstructor(nn.Module):
    """Learn a dense adjacency matrix from two node-embedding tables.

    The adjacency is ``softmax(relu(E1 @ E2^T))`` taken row-wise, so every
    row is a probability distribution over nodes.
    """

    def __init__(self, node_num, embed_dim):
        super().__init__()
        # Two independent embedding tables, each (node_num, embed_dim).
        # Creation order (nodevec1 first) is kept so seeded init is unchanged.
        self.nodevec1 = nn.Parameter(torch.randn(node_num, embed_dim))
        self.nodevec2 = nn.Parameter(torch.randn(node_num, embed_dim))

    def forward(self):
        """Return a (node_num, node_num) adjacency whose rows sum to 1."""
        similarity = self.nodevec1 @ self.nodevec2.t()
        return torch.softmax(F.relu(similarity), dim=-1)
class GraphConvBlock(nn.Module):
    """One graph-convolution step with a residual shortcut.

    Computes ``relu(theta(adj @ x) + shortcut(x))``. The shortcut is the
    identity when input and output widths match, and a learned linear
    projection otherwise.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Identity shortcut is only possible when the feature widths agree.
        self.residual = input_dim == output_dim
        self.theta = nn.Linear(input_dim, output_dim)
        if not self.residual:
            self.res_proj = nn.Linear(input_dim, output_dim)

    def forward(self, x, adj):
        """Aggregate neighbours via ``adj``, transform, add shortcut, apply ReLU."""
        shortcut = x if self.residual else self.res_proj(x)
        aggregated = self.theta(adj @ x)
        return F.relu(aggregated + shortcut)
class MANBA_Block(nn.Module):
    """Post-norm encoder layer: self-attention then a feed-forward network.

    Each sub-layer is wrapped in a residual connection followed by LayerNorm.

    Args:
        input_dim: feature width of the input; must be divisible by
            ``num_heads`` (required by ``nn.MultiheadAttention``).
        hidden_dim: inner width of the feed-forward network.
        num_heads: number of attention heads. Defaults to 4, the value that
            was previously hard-coded, so existing callers are unaffected.
    """

    def __init__(self, input_dim, hidden_dim, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=input_dim, num_heads=num_heads, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        """Apply the layer to ``x`` of shape (batch, seq, input_dim).

        The layout follows from ``batch_first=True``. Returns a tensor of the
        same shape.
        """
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)
        return self.norm2(x + self.ffn(x))
class SandwichBlock(nn.Module):
    """Attention -> graph convolution -> attention over node features.

    The graph convolution uses an adjacency learned by
    ``DynamicGraphConstructor``. Both attention layers use a feed-forward
    width of ``2 * hidden_dim``.
    """

    def __init__(self, num_nodes, embed_dim, hidden_dim):
        super().__init__()
        ffn_width = 2 * hidden_dim
        # Submodule creation order is kept identical so seeded init matches.
        self.manba1 = MANBA_Block(hidden_dim, ffn_width)
        self.graph_constructor = DynamicGraphConstructor(num_nodes, embed_dim)
        self.gc = GraphConvBlock(hidden_dim, hidden_dim)
        self.manba2 = MANBA_Block(hidden_dim, ffn_width)

    def forward(self, h):
        """Run the attention / graph-conv / attention sandwich on ``h``."""
        adjacency = self.graph_constructor()
        mixed = self.gc(self.manba1(h), adjacency)
        return self.manba2(mixed)
class MLP(nn.Module):
    """Feed-forward stack: Linear/activation pairs followed by a final Linear.

    Args:
        in_dim: input feature width.
        hidden_dims: iterable of hidden widths. A list, tuple, or any other
            iterable is accepted; an empty iterable gives a single Linear.
        out_dim: output feature width.
        activation: module class instantiated after every hidden Linear.
    """

    def __init__(self, in_dim, hidden_dims, out_dim, activation=nn.ReLU):
        super().__init__()
        # Unpacking (rather than list concatenation) accepts tuples/iterables.
        dims = [in_dim, *hidden_dims, out_dim]
        layers = []
        # Hidden layers: each consecutive pair except the last gets an activation.
        for d_in, d_out in zip(dims[:-2], dims[1:-1]):
            layers += [nn.Linear(d_in, d_out), activation()]
        # Output layer has no activation.
        layers.append(nn.Linear(dims[-2], dims[-1]))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stack to the last dimension of ``x``."""
        return self.net(x)
# ------------------------- EXP with CycleNet -------------------------
class EXP(nn.Module):
    """Spatio-temporal forecaster with CycleNet-style cycle removal.

    Pipeline:
    1. Subtract a learned cycle from the flow signal.
    2. Project each node's input series to ``hidden_dim``.
    3. Add time-of-day and day-of-week embeddings.
    4. Apply two sandwich blocks.
    5. Project to ``horizon * output_dim``.
    6. Add the cycle back for the forecast window.

    ``args`` is a dict-like config.
    Required keys: ``horizon``, ``output_dim``, ``num_nodes``, ``cycle_len``.
    Optional keys (defaults): ``in_len`` (12), ``hidden_dim`` (64),
    ``embed_dim`` (16), ``time_slots`` (288).
    """

    def __init__(self, args):
        super().__init__()
        self.horizon = args["horizon"]  # number of forecast steps
        self.output_dim = args["output_dim"]  # output features per node
        self.seq_len = args.get("in_len", 12)  # input sequence length
        self.hidden_dim = args.get("hidden_dim", 64)
        self.num_nodes = args["num_nodes"]
        self.embed_dim = args.get("embed_dim", 16)

        # Time-of-day and day-of-week embeddings.
        self.time_slots = args.get("time_slots", 288)
        self.time_embedding = nn.Embedding(self.time_slots, self.hidden_dim)
        self.day_embedding = nn.Embedding(7, self.hidden_dim)

        # Learned cycle table: one value per (phase, node).
        self.cycleQueue = RecurrentCycle(
            cycle_len=args["cycle_len"], channel_size=self.num_nodes
        )

        # Input projection: each node's length-seq_len series -> hidden_dim.
        self.input_proj = MLP(self.seq_len, [self.hidden_dim], self.hidden_dim)

        # Two stacked sandwich blocks.
        self.sandwich1 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        self.sandwich2 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)

        # Output projection: hidden_dim -> horizon * output_dim per node.
        self.out_proj = MLP(
            self.hidden_dim, [2 * self.hidden_dim], self.horizon * self.output_dim
        )

    def forward(self, x, cycle_index):
        """Forecast ``horizon`` steps.

        Args:
            x: (B, T, N, D) tensor with D >= 3.
                - channel 0 is the flow signal;
                - channel 1 is the time-of-day feature;
                - channel 2 is the day-of-week index.
                T must equal ``seq_len`` because ``input_proj`` maps
                length-``seq_len`` series.
            cycle_index: (B,) integer tensor, the cycle phase of the first
                input step.

        Returns:
            (B, horizon, N, output_dim) tensor.
        """
        # 1) Split flow and time features; integer indexing drops the channel dim.
        x_flow = x[..., 0]  # (B, T, N)
        x_time = x[..., 1]  # (B, T, N)
        x_day = x[..., 2]  # (B, T, N)
        B, T, N = x_flow.shape

        # 2) Remove the learned cycle. RecurrentCycle already returns (B, T, N).
        # No .squeeze(1) here: with T == 1 it would drop the time axis and
        # make the subtraction mis-broadcast.
        x_flow = x_flow - self.cycleQueue(cycle_index, T)

        # 3) Per-node sequence projection.
        h0 = x_flow.permute(0, 2, 1).reshape(B * N, T)  # (B*N, T)
        h0 = self.input_proj(h0).view(B, N, self.hidden_dim)

        # 4) Add embeddings of the last input step's time-of-day / day-of-week.
        # NOTE(review): assumes x_time is scaled so that x_time * (time_slots - 1)
        # falls in [0, time_slots - 1]; .long() truncates. Confirm against
        # the data pipeline's encoding.
        t_idx = (x_time[:, -1] * (self.time_slots - 1)).long()  # (B, N)
        d_idx = x_day[:, -1].long()  # (B, N)
        h0 = h0 + self.time_embedding(t_idx) + self.day_embedding(d_idx)

        # 5) Sandwich blocks (residual connection around the first only).
        h1 = self.sandwich1(h0) + h0
        h2 = self.sandwich2(h1)

        # 6) Output projection and reshape.
        out = self.out_proj(h2)  # (B, N, H*O)
        out = out.view(B, N, self.horizon, self.output_dim)  # (B, N, H, O)
        out = out.permute(0, 2, 1, 3)  # (B, H, N, O)

        # 7) Add back the cycle for the forecast window, which begins
        # seq_len steps after the first input step.
        idx_out = (cycle_index + self.seq_len) % self.cycleQueue.cycle_len
        cyc_out = self.cycleQueue(idx_out, self.horizon)  # (B, H, N)
        # Add a trailing axis to broadcast over output_dim. No .squeeze(1):
        # with horizon == 1 it would collapse H and produce a wrong-shaped sum.
        cyc_out = cyc_out.unsqueeze(-1)  # (B, H, N, 1)
        return out + cyc_out