# TrafficWheel/model/PDG2SEQ/PDG2Seqb.py
import math
from typing import List, Tuple

import torch
import torch.nn as nn


class HyperMLP(nn.Module):
    """General hypernetwork MLP with configurable hidden dims."""

    def __init__(self, in_dim: int, out_dim: int, hidden_dims: List[int], activation=nn.Sigmoid):
        super().__init__()
        layers = []
        dims = [in_dim] + hidden_dims + [out_dim]
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            # no activation after the final linear layer
            if i < len(dims) - 2:
                layers.append(activation())
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class ChebConv(nn.Module):
    """Chebyshev-style graph convolution with per-sample dynamic weights,
    mixed from a shared pool by a hypernetwork."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        cheb_k: int,
        embed_dim: int,
        hidden_dims: List[int]
    ):
        super().__init__()
        # weight pool: [embed_dim, cheb_k*2+1, in_channels, out_channels];
        # the cheb_k*2+1 term count assumes exactly two supports
        # (the identity term plus cheb_k hops for each support)
        self.weight_pool = nn.Parameter(
            torch.Tensor(embed_dim, cheb_k * 2 + 1, in_channels, out_channels)
        )
        self.bias_pool = nn.Parameter(torch.Tensor(embed_dim, out_channels))
        # hypernetwork for the dynamic mix coefficients; it consumes the
        # node-pooled input ([B, in_channels]), so in_dim is in_channels
        self.hyper = HyperMLP(in_dim=in_channels, out_dim=embed_dim, hidden_dims=hidden_dims)
        self.k = cheb_k
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight_pool)
        nn.init.zeros_(self.bias_pool)

    def forward(
        self,
        x: torch.Tensor,              # [B, N, in_channels]
        supports: List[torch.Tensor]  # list of [N, N] adjacency matrices
    ) -> torch.Tensor:
        B, N, C_in = x.shape
        # hyper coefficients, one vector per sample
        agg = x.mean(dim=1)      # [B, in_channels]
        coeff = self.hyper(agg)  # [B, embed_dim]
        # mix dynamic weights and bias from the pools
        # weight_pool: [E, K, in, out]; coeff: [B, E] -> W: [B, K, in, out]
        W = torch.einsum('be,ekio->bkio', coeff, self.weight_pool)  # [B, K, in, out]
        b = torch.einsum('be,eo->bo', coeff, self.bias_pool)        # [B, out]
        # diffusion terms: identity plus cheb_k hops per support,
        # 1 + cheb_k*len(supports) terms in total (must equal K = cheb_k*2+1)
        x0 = x  # [B, N, in_channels]
        x_list = [x0]
        for support in supports:
            x1 = torch.einsum('ij,bjk->bik', support, x0)
            x_list.append(x1)
            xk = x1
            for _ in range(2, self.k + 1):
                xk = torch.einsum('ij,bjk->bik', support, xk)
                x_list.append(xk)
        assert len(x_list) == self.weight_pool.size(1), (
            f"expected {self.weight_pool.size(1)} diffusion terms, got {len(x_list)}"
        )
        # stack to [B, N, K, in] and apply the per-sample weights
        x_stack = torch.stack(x_list, dim=2)
        out = torch.einsum('bnki,bkio->bno', x_stack, W) + b.unsqueeze(1)
        return out
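
# Illustrative shape check (sizes here are assumptions, not repo values):
# with cheb_k=2 the weight pool holds 2*2+1 = 5 diffusion terms, which matches
# exactly two supports, e.g. forward/backward transition matrices.
#
#   conv = ChebConv(in_channels=1, out_channels=8, cheb_k=2, embed_dim=4, hidden_dims=[16])
#   x = torch.randn(32, 207, 1)                   # [B, N, in_channels]
#   supports = [torch.eye(207), torch.eye(207)]   # placeholder adjacencies
#   out = conv(x, supports)                       # -> [32, 207, 8]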


class PDG2SeqCell(nn.Module):
    def __init__(
        self,
        node_num: int,
        in_dim: int,
        hidden_dim: int,
        cheb_k: int,
        embed_dim: int,
        time_dim: int
    ):
        super().__init__()
        self.node_num = node_num  # needed by init_hidden
        self.hidden_dim = hidden_dim
        merge_dim = in_dim + hidden_dim
        # gates and candidate use ChebConv with hidden_dims=[time_dim]
        self.conv_gate = ChebConv(merge_dim, 2 * hidden_dim, cheb_k, embed_dim, hidden_dims=[time_dim])
        self.conv_cand = ChebConv(merge_dim, hidden_dim, cheb_k, embed_dim, hidden_dims=[time_dim])

    def forward(
        self,
        x: torch.Tensor,              # [B, N, in_dim]
        h_prev: torch.Tensor,         # [B, N, hidden_dim]
        supports: List[torch.Tensor]  # dynamic supports
    ) -> torch.Tensor:
        # AGCRN-style GRU gating: z scales h_prev inside the candidate
        # (reset role), r mixes old and new states (update role)
        merged = torch.cat([x, h_prev], dim=-1)
        gates = self.conv_gate(merged, supports)
        z, r = gates.chunk(2, dim=-1)
        z, r = torch.sigmoid(z), torch.sigmoid(r)
        merged_cand = torch.cat([x, z * h_prev], dim=-1)
        h_tilde = torch.tanh(self.conv_cand(merged_cand, supports))
        h = r * h_prev + (1 - r) * h_tilde
        return h

    def init_hidden(self, batch_size: int) -> torch.Tensor:
        return torch.zeros(
            batch_size, self.node_num, self.hidden_dim,
            device=next(self.parameters()).device
        )
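
# Example cell step (illustrative; sizes are placeholder assumptions):
#
#   cell = PDG2SeqCell(node_num=207, in_dim=1, hidden_dim=64,
#                      cheb_k=2, embed_dim=8, time_dim=16)
#   h = cell.init_hidden(batch_size=32)             # [32, 207, 64]
#   supports = [torch.eye(207), torch.eye(207)]     # two placeholder supports
#   h = cell(torch.randn(32, 207, 1), h, supports)  # one recurrent step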


class PDG2SeqEncoder(nn.Module):
    def __init__(
        self,
        node_num: int,
        input_dim: int,
        hidden_dim: int,
        cheb_k: int,
        embed_dim: int,
        time_dim: int,
        num_layers: int = 1
    ):
        super().__init__()
        self.cells = nn.ModuleList([
            PDG2SeqCell(node_num, input_dim if i == 0 else hidden_dim,
                        hidden_dim, cheb_k, embed_dim, time_dim)
            for i in range(num_layers)
        ])

    def forward(
        self,
        x_seq: torch.Tensor,  # [B, T, N, in_dim]
        h0: torch.Tensor,     # [num_layers, B, N, hidden]
        supports_seq: List[List[torch.Tensor]]  # one support list per step
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        B, T, N, _ = x_seq.shape
        states = []
        inputs = x_seq
        for layer, cell in enumerate(self.cells):
            h = h0[layer]
            outs = []
            for t in range(T):
                h = cell(inputs[:, t], h, supports_seq[t])
                outs.append(h)
            # this layer's outputs feed the next layer
            inputs = torch.stack(outs, dim=1)
            states.append(h)
        # inputs: [B, T, N, hidden]; states: [num_layers, B, N, hidden]
        return inputs, torch.stack(states, dim=0)


class PDG2SeqDecoder(nn.Module):
    def __init__(
        self,
        node_num: int,
        input_dim: int,
        hidden_dim: int,
        cheb_k: int,
        embed_dim: int,
        time_dim: int,
        num_layers: int = 1
    ):
        super().__init__()
        self.cells = nn.ModuleList([
            PDG2SeqCell(node_num, input_dim if i == 0 else hidden_dim,
                        hidden_dim, cheb_k, embed_dim, time_dim)
            for i in range(num_layers)
        ])

    def forward(
        self,
        go: torch.Tensor,             # [B, N, input_dim]
        h_prev: torch.Tensor,         # [num_layers, B, N, hidden]
        supports: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        new_states = []
        inp = go
        for layer, cell in enumerate(self.cells):
            h = cell(inp, h_prev[layer], supports)
            new_states.append(h)
            inp = h
        return inp, torch.stack(new_states, dim=0)


class PDG2Seq(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.num_nodes = args['num_nodes']
        self.horizon = args['horizon']
        self.rnn = args['rnn_units']
        self.encoder = PDG2SeqEncoder(
            self.num_nodes, args['input_dim'], args['rnn_units'],
            args['cheb_k'], args['embed_dim'], args['time_dim'], args['num_layers']
        )
        self.decoder = PDG2SeqDecoder(
            self.num_nodes, args['output_dim'], args['rnn_units'],
            args['cheb_k'], args['embed_dim'], args['time_dim'], args['num_layers']
        )
        self.proj = nn.Linear(args['rnn_units'], args['output_dim'])
        # 288 five-minute slots per day, 7 days per week
        self.time_emb = nn.Embedding(288, args['time_dim'])
        self.week_emb = nn.Embedding(7, args['time_dim'])
        # decay constant for scheduled sampling (reuses the lr_decay_step arg)
        self.cl_decay = args['lr_decay_step']

    def forward(
        self,
        source: torch.Tensor,  # [B, T1, N, D]; last two features: time-of-day, day-of-week
        target: torch.Tensor = None,
        batches_seen: int = 0,
        supports_seq: List[List[torch.Tensor]] = None,
    ) -> torch.Tensor:
        # time indices for the (currently unused) embeddings; kept as hooks
        # for extending the embedding combination
        t_idx = (source[..., -2] * 287).long()
        w_idx = source[..., -1].long()
        # ChebConv expects exactly two supports per step (see weight_pool);
        # fall back to identity matrices when none are provided
        if supports_seq is None:
            eye = torch.eye(self.num_nodes, device=source.device)
            supports_seq = [[eye, eye] for _ in range(source.size(1))]
        B = source.size(0)
        # init hidden
        h0 = torch.zeros(
            len(self.encoder.cells), B, self.num_nodes, self.rnn,
            device=source.device
        )
        enc_out, enc_states = self.encoder(source[..., :1], h0, supports_seq)
        # project the last encoder output into the decoder's input space
        # (the decoder's first layer expects output_dim features, not rnn_units)
        go = self.proj(enc_out[:, -1])
        h_prev = enc_states
        outputs = []
        for t in range(self.horizon):
            # reuse the supports of the last observed step during decoding
            step_supports = supports_seq[-1]
            go, h_prev = self.decoder(go, h_prev, step_supports)
            pred = self.proj(go)
            outputs.append(pred)
            # feed the prediction back by default (required at eval time,
            # since the decoder input must have output_dim features)
            go = pred
            if self.training and target is not None:
                # inverse-sigmoid scheduled sampling: early in training, feed
                # the ground truth (teacher forcing) with high probability
                thresh = self.cl_decay / (self.cl_decay + math.exp(batches_seen / self.cl_decay))
                if torch.rand(1).item() < thresh:
                    go = target[:, t, :, :1]
        return torch.stack(outputs, dim=1)
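

if __name__ == "__main__":
    # Minimal smoke test; the hyperparameters below are placeholder
    # assumptions for illustration, not values from the TrafficWheel configs.
    args = {
        'num_nodes': 20, 'horizon': 12, 'rnn_units': 64,
        'input_dim': 1, 'output_dim': 1, 'cheb_k': 2,
        'embed_dim': 8, 'time_dim': 16, 'num_layers': 1,
        'lr_decay_step': 2000,
    }
    model = PDG2Seq(args)
    # [B, T, N, D]: one traffic feature plus time-of-day and day-of-week
    src = torch.rand(4, 12, args['num_nodes'], 3)
    out = model(src)  # identity supports are used as the fallback
    print(out.shape)  # expected: torch.Size([4, 12, 20, 1])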