TrafficWheel/model/EXP/EXP31.py

import torch
import torch.nn as nn
import torch.nn.functional as F
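
# Module overview (descriptive summary of the code below): per-node MLP projection of
# the flow history, additive time-of-day and day-of-week embeddings taken at the last
# observed step, two SandwichBlocks (transformer -> graph convolution over a learned
# adjacency -> transformer, with attention over the node dimension), and an MLP head
# mapping each node's hidden state to horizon * output_dim predictions.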


class DynamicGraphConstructor(nn.Module):
    def __init__(self, node_num, embed_dim):
        super().__init__()
        self.nodevec1 = nn.Parameter(
            torch.randn(node_num, embed_dim), requires_grad=True
        )
        self.nodevec2 = nn.Parameter(
            torch.randn(node_num, embed_dim), requires_grad=True
        )

    def forward(self):
        # adjacency from learned node embeddings: (node_num, node_num),
        # ReLU-sparsified and softmax-normalized along each row
        adj = torch.matmul(self.nodevec1, self.nodevec2.T)
        adj = F.relu(adj)
        adj = F.softmax(adj, dim=-1)
        return adj


class GraphConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.theta = nn.Linear(input_dim, output_dim)
        self.residual = input_dim == output_dim
        if not self.residual:
            self.res_proj = nn.Linear(input_dim, output_dim)

    def forward(self, x, adj):
        # x: (batch, num_nodes, input_dim); adj: (num_nodes, num_nodes),
        # broadcast over the batch dimension to aggregate neighbor features
        res = x
        x = torch.matmul(adj, x)
        x = self.theta(x)
        x = x + (res if self.residual else self.res_proj(res))
        return F.relu(x)


class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, dim_feedforward=None):
        super().__init__()
        # feedforward dimension defaults to 2*embed_dim if not provided
        ff_dim = dim_feedforward if dim_feedforward is not None else 2 * embed_dim
        self.layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=ff_dim, batch_first=True
        )

    def forward(self, x):
        # x: (batch, seq_len, embed_dim)
        return self.layer(x)


class SandwichBlock(nn.Module):
    def __init__(self, num_nodes, embed_dim, hidden_dim, num_heads=4):
        super().__init__()
        self.transformer1 = TransformerBlock(
            hidden_dim, num_heads=num_heads, dim_feedforward=hidden_dim * 2
        )
        self.graph_constructor = DynamicGraphConstructor(num_nodes, embed_dim)
        self.gc = GraphConvBlock(hidden_dim, hidden_dim)
        self.transformer2 = TransformerBlock(
            hidden_dim, num_heads=num_heads, dim_feedforward=hidden_dim * 2
        )

    def forward(self, h):
        # h: (batch, num_nodes, hidden_dim); the transformers treat num_nodes as the
        # sequence length, so attention runs across nodes
        h1 = self.transformer1(h)
        adj = self.graph_constructor()
        h2 = self.gc(h1, adj)
        h3 = self.transformer2(h2)
        return h3


class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dims, out_dim, activation=nn.ReLU):
        super().__init__()
        dims = [in_dim] + hidden_dims + [out_dim]
        layers = []
        for i in range(len(dims) - 2):
            layers += [nn.Linear(dims[i], dims[i + 1]), activation()]
        layers += [nn.Linear(dims[-2], dims[-1])]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class EXP(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.horizon = args["horizon"]
        self.output_dim = args["output_dim"]
        self.seq_len = args.get("in_len", 12)
        self.hidden_dim = args.get("hidden_dim", 64)
        self.num_nodes = args["num_nodes"]
        self.embed_dim = args.get("embed_dim", 16)
        # time embeddings; time_slots defaults to 24*60 / time_slot, e.g. 288 for 5-minute intervals
        self.time_slots = args.get("time_slots", 24 * 60 // args.get("time_slot", 5))
        self.time_embedding = nn.Embedding(self.time_slots, self.hidden_dim)
        self.day_embedding = nn.Embedding(7, self.hidden_dim)
        # input projection
        self.input_proj = MLP(
            in_dim=self.seq_len, hidden_dims=[self.hidden_dim], out_dim=self.hidden_dim
        )
        # two Sandwich blocks with transformer
        self.sandwich1 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        self.sandwich2 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        # output projection
        self.out_proj = MLP(
            in_dim=self.hidden_dim,
            hidden_dims=[2 * self.hidden_dim],
            out_dim=self.horizon * self.output_dim,
        )

    def forward(self, x):
        # x: (B, T, N, D_total) with channels [flow, time-of-day, day-of-week]
        x_flow = x[..., 0]
        x_time = x[..., 1]
        x_day = x[..., 2]
        B, T, N = x_flow.shape
        assert T == self.seq_len
        # project flow history: (B, T, N) -> (B, N, hidden_dim)
        x_flat = x_flow.permute(0, 2, 1).reshape(B * N, T)
        h0 = self.input_proj(x_flat).view(B, N, self.hidden_dim)
        # time embeddings from the last observed step; time-of-day is expected to be
        # normalized to [0, 1] so it maps to a valid slot index
        t_idx = (x_time[:, -1, :] * (self.time_slots - 1)).long()
        d_idx = x_day[:, -1, :].long()
        time_emb = self.time_embedding(t_idx)
        day_emb = self.day_embedding(d_idx)
        # inject embeddings
        h0 = h0 + time_emb + day_emb
        # Sandwich blocks with residuals
        h1 = self.sandwich1(h0)
        h1 = h1 + h0
        h2 = self.sandwich2(h1)
        # output: (B, N, horizon * output_dim) -> (B, horizon, N, output_dim)
        out = self.out_proj(h2)
        out = out.view(B, N, self.horizon, self.output_dim)
        out = out.permute(0, 2, 1, 3)
        return out
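

# Minimal smoke test (a sketch, not part of the original module): the arg values and
# dummy input shapes below are assumptions based on how EXP reads `args` and `x` above.
if __name__ == "__main__":
    args = {
        "horizon": 12,
        "output_dim": 1,
        "in_len": 12,
        "hidden_dim": 64,
        "num_nodes": 8,
        "embed_dim": 16,
        "time_slot": 5,
    }
    model = EXP(args)
    B, T, N = 2, args["in_len"], args["num_nodes"]
    flow = torch.randn(B, T, N)                    # flow values
    tod = torch.rand(B, T, N)                      # time of day, normalized to [0, 1]
    dow = torch.randint(0, 7, (B, T, N)).float()   # day-of-week index
    x = torch.stack([flow, tod, dow], dim=-1)      # (B, T, N, 3)
    y = model(x)
    print(y.shape)  # expected: (B, horizon, N, output_dim) -> (2, 12, 8, 1)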