# TrafficWheel/model/EXP/trash/EXP30.py

import torch
import torch.nn as nn
import torch.nn.functional as F
"""
在 EXP 模型中添加趋势专家、周期专家和物理专家并通过门控网络Gating Network动态融合专家输出
"""
class DynamicGraphConstructor(nn.Module):
    def __init__(self, node_num, embed_dim):
        super().__init__()
        self.nodevec1 = nn.Parameter(torch.randn(node_num, embed_dim), requires_grad=True)
        self.nodevec2 = nn.Parameter(torch.randn(node_num, embed_dim), requires_grad=True)

    def forward(self):
        adj = torch.matmul(self.nodevec1, self.nodevec2.T)
        adj = F.relu(adj)
        adj = F.softmax(adj, dim=-1)
        return adj
class GraphConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.theta = nn.Linear(input_dim, output_dim)
        self.residual = (input_dim == output_dim)
        if not self.residual:
            self.res_proj = nn.Linear(input_dim, output_dim)

    def forward(self, x, adj):
        res = x
        x = torch.matmul(adj, x)
        x = self.theta(x)
        x = x + (res if self.residual else self.res_proj(res))
        return F.relu(x)
class MANBA_Block(nn.Module):
    """Transformer-style block: multi-head self-attention and an FFN, each with residual + LayerNorm."""
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=input_dim, num_heads=4, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        res = x
        x_attn, _ = self.attn(x, x, x)
        x = self.norm1(res + x_attn)
        res2 = x
        x_ffn = self.ffn(x)
        x = self.norm2(res2 + x_ffn)
        return x
class SandwichBlock(nn.Module):
    def __init__(self, num_nodes, embed_dim, hidden_dim):
        super().__init__()
        self.manba1 = MANBA_Block(hidden_dim, hidden_dim * 2)
        self.graph_constructor = DynamicGraphConstructor(num_nodes, embed_dim)
        self.gc = GraphConvBlock(hidden_dim, hidden_dim)
        self.manba2 = MANBA_Block(hidden_dim, hidden_dim * 2)

    def forward(self, h):
        h1 = self.manba1(h)
        adj = self.graph_constructor()
        h2 = self.gc(h1, adj)
        h3 = self.manba2(h2)
        return h3
# --------- New: expert network definitions ---------
class TrendExpert(nn.Module):
    """Captures long-term trends in the data."""
    def __init__(self, hidden_dim):
        super().__init__()
        self.trend_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, h):
        return self.trend_mlp(h)
class PeriodicExpert(nn.Module):
    """Captures periodic patterns."""
    def __init__(self, hidden_dim):
        super().__init__()
        self.periodic_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, h):
        # Placeholder: could be extended to Fourier-domain processing
        return self.periodic_mlp(h)
class PhysicalExpert(nn.Module):
    """Graph-convolution expert grounded in the physical (spatial) structure."""
    def __init__(self, num_nodes, embed_dim, hidden_dim):
        super().__init__()
        self.graph_constructor = DynamicGraphConstructor(num_nodes, embed_dim)
        self.graph_conv = GraphConvBlock(hidden_dim, hidden_dim)

    def forward(self, h):
        adj = self.graph_constructor()
        return self.graph_conv(h, adj)
class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dims, out_dim, activation=nn.ReLU):
        super().__init__()
        dims = [in_dim] + hidden_dims + [out_dim]
        layers = []
        for i in range(len(dims) - 2):
            layers += [nn.Linear(dims[i], dims[i + 1]), activation()]
        layers += [nn.Linear(dims[-2], dims[-1])]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class EXP(nn.Module):
    def __init__(self, args):
        super().__init__()
        # Original configuration
        self.horizon = args['horizon']
        self.output_dim = args['output_dim']
        self.seq_len = args.get('in_len', 12)
        self.hidden_dim = args.get('hidden_dim', 64)
        self.num_nodes = args['num_nodes']
        self.embed_dim = args.get('embed_dim', 16)
        # Time embeddings
        self.time_slots = args.get('time_slots', 24 * 60 // args.get('time_slot', 5))
        self.time_embedding = nn.Embedding(self.time_slots, self.hidden_dim)
        self.day_embedding = nn.Embedding(7, self.hidden_dim)
        # Input flow projection
        self.input_proj = MLP(
            in_dim=self.seq_len,
            hidden_dims=[self.hidden_dim],
            out_dim=self.hidden_dim
        )
        # --------- New: experts and gating network ---------
        self.num_experts = 3
        self.trend_expert = TrendExpert(self.hidden_dim)
        self.periodic_expert = PeriodicExpert(self.hidden_dim)
        self.physical_expert = PhysicalExpert(self.num_nodes, self.embed_dim, self.hidden_dim)
        # Gating network: generates expert weights dynamically from h0
        self.gating = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.num_experts)
        )
        # Two Sandwich blocks
        self.sandwich1 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        self.sandwich2 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        # Output projection
        self.out_proj = MLP(
            in_dim=self.hidden_dim,
            hidden_dims=[2 * self.hidden_dim],
            out_dim=self.horizon * self.output_dim
        )
    def forward(self, x):
        # x: (B, T, N, D_total)
        x_flow = x[..., 0]  # traffic flow
        x_time = x[..., 1]  # time-of-day, normalized to [0, 1]
        x_day = x[..., 2]   # day of week
        B, T, N = x_flow.shape
        assert T == self.seq_len
        # 1) Project the flow history
        x_flat = x_flow.permute(0, 2, 1).reshape(B * N, T)
        h0 = self.input_proj(x_flat).view(B, N, self.hidden_dim)
        # 2) Time-of-day and day-of-week embeddings (taken at the last input step)
        t_idx = (x_time[:, -1, :] * (self.time_slots - 1)).long()
        d_idx = x_day[:, -1, :].long()
        time_emb = self.time_embedding(t_idx)
        day_emb = self.day_embedding(d_idx)
        # Inject the embeddings
        h0 = h0 + time_emb + day_emb
        # 3) Gated fusion of the expert outputs
        g = self.gating(h0)  # (B, N, 3)
        g = F.softmax(g, dim=-1)
        h_trend = self.trend_expert(h0)
        h_periodic = self.periodic_expert(h0)
        h_physical = self.physical_expert(h0)
        # Weighted sum of the experts
        h0 = g[..., 0:1] * h_trend + g[..., 1:2] * h_periodic + g[..., 2:3] * h_physical
        # 4) Sandwich blocks + residual connection
        h1 = self.sandwich1(h0)
        h1 = h1 + h0
        h2 = self.sandwich2(h1)
        # 5) Output projection
        out = self.out_proj(h2)
        out = out.view(B, N, self.horizon, self.output_dim)
        out = out.permute(0, 2, 1, 3)
        return out
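
# Minimal smoke test (not part of the original file): a sketch of how this model might be
# driven, assuming the config keys read in __init__ and the (B, T, N, D) input layout used
# in forward(). All shapes and values below are illustrative only.
if __name__ == "__main__":
    args = {
        'horizon': 12,
        'output_dim': 1,
        'num_nodes': 8,
        'in_len': 12,
        'hidden_dim': 64,
        'embed_dim': 16,
        'time_slot': 5,  # minutes per slot -> 288 slots per day
    }
    model = EXP(args)
    B, T, N = 2, args['in_len'], args['num_nodes']
    flow = torch.randn(B, T, N)                   # channel 0: traffic flow
    tod = torch.rand(B, T, N)                     # channel 1: time-of-day in [0, 1]
    dow = torch.randint(0, 7, (B, T, N)).float()  # channel 2: day-of-week index
    x = torch.stack([flow, tod, dow], dim=-1)     # (B, T, N, 3)
    out = model(x)
    print(out.shape)  # expected: (B, horizon, N, output_dim) -> (2, 12, 8, 1)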