# TrafficWheel/model/EXP/trash/EXP28.py

"""
Complete EXP model: recovers density from flow via the Greenshields model,
couples it with an LWR conservation-law physics-guided module, and supports
flow-only input data.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class FlowToDensity(nn.Module):
    """
    Invert the Greenshields fundamental diagram to recover density from flow:
        q = v_f * k * (1 - k / k_j)
    Solving this quadratic in k yields the density.
    """

    def __init__(self, v_f=15.0, k_j=1.0):
        super().__init__()
        self.v_f = nn.Parameter(torch.tensor(v_f), requires_grad=False)
        self.k_j = nn.Parameter(torch.tensor(k_j), requires_grad=False)

    def forward(self, q):  # q: (B, T, N)
        # Rewrite q = v_f * k - (v_f / k_j) * k**2 as a*k**2 + b*k + c = 0.
        a = -self.v_f / self.k_j
        b = self.v_f
        c = -q
        delta = b**2 - 4 * a * c
        delta = torch.clamp(delta, min=1e-6)  # guard against flows slightly above capacity
        sqrt_delta = torch.sqrt(delta)
        k1 = (-b + sqrt_delta) / (2 * a)
        k2 = (-b - sqrt_delta) / (2 * a)
        # Prefer the root in the physically valid range (0, k_j).
        k = torch.where((k1 > 0) & (k1 < self.k_j), k1, k2)
        return k
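

# Sanity sketch (illustrative, not part of the model): the quadratic inversion
# should round-trip through the Greenshields relation q = v_f * k * (1 - k / k_j)
# for flows below capacity q_max = v_f * k_j / 4.
def _demo_flow_to_density():
    torch.manual_seed(0)
    f2d = FlowToDensity(v_f=15.0, k_j=1.0)
    q = torch.rand(2, 12, 4) * 3.0  # keep q under q_max = 15 * 1 / 4 = 3.75
    k = f2d(q)
    q_rec = f2d.v_f * k * (1 - k / f2d.k_j)  # forward Greenshields map
    assert torch.allclose(q, q_rec, atol=1e-3)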


class FundamentalDiagram(nn.Module):
    """
    Greenshields fundamental diagram: compute speed and flux from density.
    """

    def __init__(self, v_free=30.0, k_jam=200.0):
        super().__init__()
        self.v_free = nn.Parameter(torch.tensor(v_free), requires_grad=True)
        self.k_jam = nn.Parameter(torch.tensor(k_jam), requires_grad=True)

    def forward(self, density):
        speed = self.v_free * (1 - density / self.k_jam)
        flux = density * speed
        return speed, flux
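

# Illustrative property check (using the class defaults; not part of the model):
# Greenshields flux peaks at k = k_jam / 2 with q_max = v_free * k_jam / 4.
def _demo_fundamental_diagram():
    fd = FundamentalDiagram(v_free=30.0, k_jam=200.0)
    with torch.no_grad():
        k = torch.linspace(0.0, 200.0, 201)
        speed, flux = fd(k)
    assert torch.argmax(flux).item() == 100  # k = k_jam / 2
    assert torch.isclose(flux.max(), torch.tensor(1500.0))  # 30 * 200 / 4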


class ConservationLayer(nn.Module):
    """
    Conservation layer based on a discretization of the LWR equation.
    """

    def __init__(self, dt=1.0, dx=1.0):
        super().__init__()
        self.dt = dt
        self.dx = dx

    def forward(self, density, flux, adj):
        # density, flux: (B, N); adj: (N, N)
        # outflow: flux leaving each node toward its neighbors
        outflow = flux @ adj
        # inflow: neighbors' flux entering each node
        inflow = flux @ adj.T
        # density update from the discrete conservation law
        delta = (inflow - outflow) * (self.dt / self.dx)
        d_next = density + delta
        return d_next.clamp(min=0.0)
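

# Mass-balance sketch (illustrative assumption: a doubly-stochastic adjacency,
# whose rows and columns each sum to 1): under it the discrete LWR update
# moves density between nodes without creating or destroying total mass,
# up to the final non-negativity clamp.
def _demo_conservation_layer():
    layer = ConservationLayer(dt=1.0, dx=1.0)
    adj = torch.tensor([[0.0, 1.0, 0.0],
                        [0.0, 0.0, 1.0],
                        [1.0, 0.0, 0.0]])  # a 3-node ring, doubly stochastic
    density = torch.tensor([[0.6, 0.2, 0.2]])
    flux = torch.tensor([[0.3, 0.1, 0.1]])
    d_next = layer(density, flux, adj)
    assert torch.isclose(d_next.sum(), density.sum())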


class DynamicGraphConstructor(nn.Module):
    def __init__(self, node_num, embed_dim):
        super().__init__()
        self.nodevec1 = nn.Parameter(torch.randn(node_num, embed_dim))
        self.nodevec2 = nn.Parameter(torch.randn(node_num, embed_dim))

    def forward(self):
        adj = torch.matmul(self.nodevec1, self.nodevec2.T)
        adj = F.relu(adj)
        adj = F.softmax(adj, dim=-1)
        return adj
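

# Quick property check (illustrative): the softmax over the last dimension
# makes each row of the learned adjacency a probability distribution.
def _demo_dynamic_graph():
    constructor = DynamicGraphConstructor(node_num=5, embed_dim=8)
    with torch.no_grad():
        adj = constructor()
    assert adj.shape == (5, 5)
    assert torch.allclose(adj.sum(dim=-1), torch.ones(5))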


class GraphConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.theta = nn.Linear(input_dim, output_dim)
        self.residual = input_dim == output_dim
        if not self.residual:
            self.res_proj = nn.Linear(input_dim, output_dim)

    def forward(self, x, adj):
        res = x
        x = torch.matmul(adj, x)
        x = self.theta(x)
        x = x + (res if self.residual else self.res_proj(res))
        return F.relu(x)


class MANBA_Block(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=input_dim, num_heads=4, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        res = x
        x_attn, _ = self.attn(x, x, x)
        x = self.norm1(res + x_attn)
        res2 = x
        x_ffn = self.ffn(x)
        x = self.norm2(res2 + x_ffn)
        return x


class SandwichBlock(nn.Module):
    def __init__(self, num_nodes, embed_dim, hidden_dim):
        super().__init__()
        self.manba1 = MANBA_Block(hidden_dim, hidden_dim * 2)
        self.graph_constructor = DynamicGraphConstructor(num_nodes, embed_dim)
        self.gc = GraphConvBlock(hidden_dim, hidden_dim)
        self.manba2 = MANBA_Block(hidden_dim, hidden_dim * 2)
        self.fundamental = FundamentalDiagram()
        self.conserve = ConservationLayer()

    def forward(self, h, density):
        # h: (B, N, D); density: (B, N)
        h1 = self.manba1(h)
        adj = self.graph_constructor()
        _, flux = self.fundamental(density)
        density_next = self.conserve(density, flux, adj)
        h1_updated = h1 + density_next.unsqueeze(-1)
        h2 = self.gc(h1_updated, adj)
        h3 = self.manba2(h2)
        return h3, density_next
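

# Usage sketch (illustrative shapes only): a sandwich step consumes node
# embeddings together with a per-node density and returns updated versions
# of both, so blocks can be chained while the physics state advances.
def _demo_sandwich_block():
    block = SandwichBlock(num_nodes=5, embed_dim=8, hidden_dim=64)
    h = torch.randn(2, 5, 64)   # (B, N, D)
    density = torch.rand(2, 5)  # (B, N)
    h_next, density_next = block(h, density)
    assert h_next.shape == (2, 5, 64)
    assert density_next.shape == (2, 5)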


class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dims, out_dim, activation=nn.ReLU):
        super().__init__()
        dims = [in_dim] + hidden_dims + [out_dim]
        layers = []
        for i in range(len(dims) - 2):
            layers += [nn.Linear(dims[i], dims[i + 1]), activation()]
        layers += [nn.Linear(dims[-2], dims[-1])]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class EXP(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.horizon = args["horizon"]
        self.output_dim = args["output_dim"]
        self.seq_len = args.get("in_len", 12)
        self.hidden_dim = args.get("hidden_dim", 64)
        self.num_nodes = args["num_nodes"]
        self.embed_dim = args.get("embed_dim", 16)
        self.time_slots = args.get("time_slots", 24 * 60 // args.get("time_slot", 5))
        self.flow_to_density = FlowToDensity(
            v_f=args.get("v_f", 15.0), k_j=args.get("k_j", 1.0)
        )
        self.time_embedding = nn.Embedding(self.time_slots, self.hidden_dim)
        self.day_embedding = nn.Embedding(7, self.hidden_dim)
        self.input_proj = MLP(
            in_dim=self.seq_len, hidden_dims=[self.hidden_dim], out_dim=self.hidden_dim
        )
        self.sandwich1 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        self.sandwich2 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        self.out_proj = MLP(
            in_dim=self.hidden_dim,
            hidden_dims=[2 * self.hidden_dim],
            out_dim=self.horizon * self.output_dim,
        )

    def forward(self, x):
        """
        x: (B, T, N, 3)
        x[..., 0] = flow, x[..., 1] = time_in_day, x[..., 2] = day_in_week
        """
        x_flow = x[..., 0]
        x_time = x[..., 1]
        x_day = x[..., 2]
        B, T, N = x_flow.shape
        assert T == self.seq_len
        x_density = self.flow_to_density(x_flow)  # (B, T, N)
        dens0 = x_density[:, -1, :]  # (B, N)
        x_flat = x_flow.permute(0, 2, 1).reshape(B * N, T)
        h0 = self.input_proj(x_flat).view(B, N, self.hidden_dim)
        t_idx = (x_time[:, -1] * (self.time_slots - 1)).long()
        d_idx = x_day[:, -1].long()
        time_emb = self.time_embedding(t_idx)
        day_emb = self.day_embedding(d_idx)
        h0 = h0 + time_emb + day_emb
        h1, dens1 = self.sandwich1(h0, dens0)
        h1 = h1 + h0
        h2, dens2 = self.sandwich2(h1, dens1)  # dens2 is the final physics state (unused)
        out = self.out_proj(h2)
        out = out.view(B, N, self.horizon, self.output_dim)
        out = out.permute(0, 2, 1, 3)  # (B, horizon, N, output_dim)
        return out
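

# End-to-end smoke test (illustrative: the args keys mirror the ones read in
# EXP.__init__; the tensors below are made-up demo data, not a real dataset).
if __name__ == "__main__":
    args = {
        "horizon": 12,
        "output_dim": 1,
        "in_len": 12,
        "hidden_dim": 64,
        "num_nodes": 10,
        "embed_dim": 16,
        "time_slot": 5,  # 5-minute slots -> 288 slots per day
    }
    model = EXP(args)
    B, T, N = 4, 12, 10
    x = torch.stack(
        [
            torch.rand(B, T, N) * 3.0,               # flow
            torch.rand(B, T, N),                     # time_in_day, scaled to [0, 1)
            torch.randint(0, 7, (B, T, N)).float(),  # day_in_week
        ],
        dim=-1,
    )
    out = model(x)
    assert out.shape == (B, 12, N, 1)  # (B, horizon, N, output_dim)
    print("EXP forward OK:", tuple(out.shape))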