# TrafficWheel/model/EXP/trash/EXP10.py

"""
KAN network: KANLinear layers stand in for the model's linear projections.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class KANLinear(nn.Module):
    """
    A simple Kolmogorov-Arnold Network linear layer.

    y_k = sum_{q=1}^Q alpha_{kq} * phi_q( sum_{i=1}^I beta_{qi} * x_i )
    """

    def __init__(self, in_features, out_features, hidden_funcs=10):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_hidden = hidden_funcs
        # mixing weights from the inputs to the Q hidden functions
        self.beta = nn.Parameter(torch.randn(hidden_funcs, in_features))
        # one univariate phi function per hidden channel
        self.phi = nn.ModuleList(
            [nn.Sequential(nn.Linear(1, 1), nn.ReLU()) for _ in range(hidden_funcs)]
        )
        # mixing weights from the hidden functions to the outputs
        self.alpha = nn.Parameter(torch.randn(out_features, hidden_funcs))

    def forward(self, x):
        # x: (..., in_features)
        # univariate projection for each hidden function: u_q = sum_i beta_{qi} * x_i
        u = torch.einsum("...i,qi->...q", x, self.beta)  # (..., Q)
        # apply each phi_q to its own channel
        u_phi = torch.stack(
            [
                self.phi[q](u[..., q].unsqueeze(-1)).squeeze(-1)
                for q in range(self.num_hidden)
            ],
            dim=-1,
        )  # (..., Q)
        # mix the hidden channels down to out_features
        y = torch.einsum("...q,kq->...k", u_phi, self.alpha)  # (..., out_features)
        return y
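
# Minimal shape-check sketch for KANLinear; the sizes below are made up for
# illustration and are not part of the model configuration.
def _demo_kan_linear():
    layer = KANLinear(in_features=8, out_features=4, hidden_funcs=6)
    y = layer(torch.randn(32, 8))
    assert y.shape == (32, 4)
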
class DynamicGraphConstructor(nn.Module):
    def __init__(self, node_num, embed_dim):
        super().__init__()
        # learned node embeddings; nn.Parameter is trainable by default
        self.nodevec1 = nn.Parameter(torch.randn(node_num, embed_dim))
        self.nodevec2 = nn.Parameter(torch.randn(node_num, embed_dim))

    def forward(self):
        # (N, D) @ (D, N) -> (N, N)
        adj = torch.matmul(self.nodevec1, self.nodevec2.T)
        adj = F.relu(adj)
        adj = F.softmax(adj, dim=-1)  # each row is normalized to sum to 1
        return adj
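
# Sanity-check sketch for DynamicGraphConstructor: softmax makes the learned
# adjacency row-stochastic. node_num and embed_dim here are hypothetical.
def _demo_dynamic_graph():
    graph = DynamicGraphConstructor(node_num=5, embed_dim=16)
    adj = graph()
    assert adj.shape == (5, 5)
    assert torch.allclose(adj.sum(dim=-1), torch.ones(5))
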
class GraphConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim, kan_hidden=8):
        super().__init__()
        self.theta = KANLinear(input_dim, output_dim, hidden_funcs=kan_hidden)
        # identity residual only works when the channel count is unchanged
        self.residual = input_dim == output_dim
        if not self.residual:
            self.res_proj = KANLinear(input_dim, output_dim, hidden_funcs=kan_hidden)

    def forward(self, x, adj):
        # x: (B, N, C), adj: (N, N)
        res = x
        # aggregate neighbor features along the graph
        x = torch.matmul(adj, x)
        # apply the KAN-based linear mapping per node
        B, N, C = x.shape
        x = x.view(B * N, C)
        x = self.theta(x)
        x = x.view(B, N, -1)
        if self.residual:
            x = x + res
        else:
            x = x + self.res_proj(res.view(B * N, C)).view(B, N, -1)
        return F.relu(x)
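
# Shape-check sketch for GraphConvBlock using a uniform row-stochastic
# adjacency; all sizes are hypothetical.
def _demo_graph_conv_block():
    block = GraphConvBlock(input_dim=16, output_dim=16, kan_hidden=4)
    x = torch.randn(2, 5, 16)  # (B, N, C)
    adj = torch.full((5, 5), 1.0 / 5)
    assert block(x, adj).shape == (2, 5, 16)
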
class MANBA_Block(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=input_dim, num_heads=4, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        # x: (B, N, C); self-attention runs over the N axis, which this model
        # treats as the sequence dimension (batch_first=True, so no transpose)
        res = x
        x_attn, _ = self.attn(x, x, x)
        x = self.norm1(res + x_attn)
        res2 = x
        x_ffn = self.ffn(x)
        x = self.norm2(res2 + x_ffn)
        return x
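
# Shape-check sketch for MANBA_Block; input_dim must be divisible by the
# 4 attention heads. The sizes here are illustrative only.
def _demo_manba_block():
    block = MANBA_Block(input_dim=16, hidden_dim=32)
    assert block(torch.randn(2, 5, 16)).shape == (2, 5, 16)
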
class EXP(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.horizon = args["horizon"]
        self.output_dim = args["output_dim"]
        self.seq_len = args.get("in_len", 12)
        self.hidden_dim = args.get("hidden_dim", 64)
        self.num_nodes = args["num_nodes"]
        kan_hidden = args.get("kan_hidden", 8)
        # dynamic graph construction
        self.graph = DynamicGraphConstructor(self.num_nodes, embed_dim=16)
        # input projection: KAN in place of a linear layer
        self.input_proj = KANLinear(
            self.seq_len, self.hidden_dim, hidden_funcs=kan_hidden
        )
        # graph convolution
        self.gc = GraphConvBlock(
            self.hidden_dim, self.hidden_dim, kan_hidden=kan_hidden
        )
        # temporal modeling: keep the MANBA block
        self.manba = MANBA_Block(self.hidden_dim, self.hidden_dim * 2)
        # output projection: KAN in place of a linear layer
        out_size = self.horizon * self.output_dim
        self.out_proj = KANLinear(self.hidden_dim, out_size, hidden_funcs=kan_hidden)

    def forward(self, x):
        # x: (B, T, N, D_total); keep only the first input feature
        x = x[..., 0]
        B, T, N = x.shape
        assert T == self.seq_len
        # input projection: (B, T, N) -> (B, N, T) -> (B*N, T)
        x = x.permute(0, 2, 1).reshape(B * N, T)
        h = self.input_proj(x)  # (B*N, hidden_dim)
        h = h.view(B, N, self.hidden_dim)
        # dynamic graph
        adj = self.graph()
        # spatial: graph convolution
        h = self.gc(h, adj)
        # temporal: MANBA
        h = self.manba(h)
        # output projection
        h_flat = h.view(B * N, self.hidden_dim)
        out = self.out_proj(h_flat)
        out = out.view(B, N, self.horizon, self.output_dim).permute(0, 2, 1, 3)
        return out
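
if __name__ == "__main__":
    # Minimal end-to-end smoke test. The args dict below is an assumed
    # configuration (e.g. 207 nodes), not one taken from the TrafficWheel
    # training setup.
    args = {
        "horizon": 12,
        "output_dim": 1,
        "in_len": 12,
        "hidden_dim": 64,
        "num_nodes": 207,
        "kan_hidden": 8,
    }
    model = EXP(args)
    x = torch.randn(2, 12, 207, 3)  # (B, T, N, D_total)
    out = model(x)
    print(out.shape)  # expected: torch.Size([2, 12, 207, 1])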