import math

import torch
import torch.nn as nn
from typing import List, Optional, Tuple


class HyperMLP(nn.Module):
    """General hypernetwork with configurable hidden dims."""

    def __init__(self, in_dim: int, out_dim: int, hidden_dims: List[int], activation=nn.Sigmoid):
        super().__init__()
        layers = []
        dims = [in_dim] + hidden_dims + [out_dim]
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            if i < len(dims) - 2:  # no activation after the output layer
                layers.append(activation())
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
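
# For example (illustrative dims, not from the original file),
# HyperMLP(in_dim=8, out_dim=10, hidden_dims=[16]) builds
# Linear(8, 16) -> Sigmoid() -> Linear(16, 10): activations sit between
# layers only, never after the output layer.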


class ChebConv(nn.Module):
    """Chebyshev graph convolution whose weights are generated per sample by a hypernetwork."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        cheb_k: int,
        embed_dim: int,
        hidden_dims: List[int],  # hidden layer sizes of the hypernetwork MLP
    ):
        super().__init__()
        # Weight pool: [embed_dim, cheb_k*2 + 1, in_channels, out_channels].
        # The cheb_k*2 + 1 slots cover the identity term plus cheb_k diffusion
        # steps over each of two supports, so forward() expects len(supports) == 2.
        self.weight_pool = nn.Parameter(
            torch.Tensor(embed_dim, cheb_k * 2 + 1, in_channels, out_channels)
        )
        self.bias_pool = nn.Parameter(torch.Tensor(embed_dim, out_channels))
        # Hypernetwork producing dynamic mixing coefficients; its input is the
        # [B, in_channels] mean-pooled feature built in forward().
        self.hyper = HyperMLP(in_dim=in_channels, out_dim=embed_dim, hidden_dims=hidden_dims)
        self.k = cheb_k
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight_pool)
        nn.init.zeros_(self.bias_pool)

    def forward(
        self,
        x: torch.Tensor,               # [B, N, in_channels]
        supports: List[torch.Tensor],  # list of two [N, N] adjacency matrices
    ) -> torch.Tensor:
        B, N, C_in = x.shape
        # Per-sample mixing coefficients from mean-pooled node features.
        agg = x.mean(dim=1)      # [B, in_channels]
        coeff = self.hyper(agg)  # [B, embed_dim]

        # Mix the pools into dynamic weights and bias.
        # weight_pool: [E, K, in, out]; coeff: [B, E] -> W: [B, K, in, out]
        W = torch.einsum('be,ekio->bkio', coeff, self.weight_pool)  # [B, K, in, out]
        b = torch.einsum('be,eo->bo', coeff, self.bias_pool)        # [B, out]

        # Gather the diffusion terms: identity plus cheb_k powers of each
        # support, i.e. 1 + len(supports)*cheb_k terms (== K for two supports).
        x0 = x  # [B, N, in_channels]
        x_list = [x0]
        for support in supports:
            x1 = torch.einsum('ij,bjk->bik', support, x0)
            x_list.append(x1)
            xk = x1
            for _ in range(2, self.k + 1):
                xk = torch.einsum('ij,bjk->bik', support, xk)
                x_list.append(xk)
        # Stack to [B, N, K, in] and contract with the dynamic weights.
        x_stack = torch.stack(x_list, dim=2)
        out = torch.einsum('bnki,bkio->bno', x_stack, W) + b.unsqueeze(1)
        return out
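
# Illustrative smoke test for ChebConv (shapes are assumptions, not from the
# original file). The weight pool expects exactly two supports, e.g. forward
# and backward transition matrices:
#
#   conv = ChebConv(in_channels=2, out_channels=8, cheb_k=2, embed_dim=10, hidden_dims=[16])
#   x = torch.randn(4, 20, 2)                                   # [B, N, C_in]
#   supports = [torch.softmax(torch.randn(20, 20), dim=1) for _ in range(2)]
#   out = conv(x, supports)                                     # -> [4, 20, 8]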


class PDG2SeqCell(nn.Module):
    """GRU-style recurrent cell whose gates are ChebConv graph convolutions."""

    def __init__(
        self,
        node_num: int,
        in_dim: int,
        hidden_dim: int,
        cheb_k: int,
        embed_dim: int,
        time_dim: int,
    ):
        super().__init__()
        self.node_num = node_num  # needed by init_hidden()
        self.hidden_dim = hidden_dim
        merge_dim = in_dim + hidden_dim
        # Gates and candidate use ChebConv with hidden_dims=[time_dim].
        self.conv_gate = ChebConv(merge_dim, 2 * hidden_dim, cheb_k, embed_dim, hidden_dims=[time_dim])
        self.conv_cand = ChebConv(merge_dim, hidden_dim, cheb_k, embed_dim, hidden_dims=[time_dim])

    def forward(
        self,
        x: torch.Tensor,               # [B, N, in_dim]
        h_prev: torch.Tensor,          # [B, N, hidden_dim]
        supports: List[torch.Tensor],  # dynamic supports
    ) -> torch.Tensor:
        merged = torch.cat([x, h_prev], dim=-1)
        gates = self.conv_gate(merged, supports)
        z, r = gates.chunk(2, dim=-1)
        z, r = torch.sigmoid(z), torch.sigmoid(r)  # z: update gate, r: reset gate
        merged_cand = torch.cat([x, r * h_prev], dim=-1)  # reset gate masks the old state
        h_tilde = torch.tanh(self.conv_cand(merged_cand, supports))
        h = z * h_prev + (1 - z) * h_tilde  # interpolate old state and candidate
        return h

    def init_hidden(self, batch_size: int) -> torch.Tensor:
        return torch.zeros(batch_size, self.node_num, self.hidden_dim,
                           device=next(self.parameters()).device)
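
# Illustrative single-step usage of PDG2SeqCell (assumed shapes, two supports):
#
#   cell = PDG2SeqCell(node_num=20, in_dim=1, hidden_dim=64,
#                      cheb_k=2, embed_dim=10, time_dim=16)
#   x = torch.randn(4, 20, 1)
#   h = cell.init_hidden(4)                                     # [4, 20, 64]
#   supports = [torch.softmax(torch.randn(20, 20), dim=1) for _ in range(2)]
#   h = cell(x, h, supports)                                    # -> [4, 20, 64]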


class PDG2SeqEncoder(nn.Module):
    def __init__(
        self,
        node_num: int,
        input_dim: int,
        hidden_dim: int,
        cheb_k: int,
        embed_dim: int,
        time_dim: int,
        num_layers: int = 1,
    ):
        super().__init__()
        self.cells = nn.ModuleList([
            PDG2SeqCell(node_num, input_dim if i == 0 else hidden_dim,
                        hidden_dim, cheb_k, embed_dim, time_dim)
            for i in range(num_layers)
        ])

    def forward(
        self,
        x_seq: torch.Tensor,                     # [B, T, N, in_dim]
        h0: torch.Tensor,                        # [num_layers, B, N, hidden]
        supports_seq: List[List[torch.Tensor]],  # per-step dynamic supports
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        B, T, N, _ = x_seq.shape
        states = []
        inputs = x_seq
        for layer, cell in enumerate(self.cells):
            h = h0[layer]
            outs = []
            for t in range(T):
                h = cell(inputs[:, t], h, supports_seq[t])
                outs.append(h)
            # This layer's output sequence feeds the next layer as input.
            inputs = torch.stack(outs, dim=1)
            states.append(h)
        # Top layer's output sequence and each layer's final hidden state.
        return inputs, torch.stack(states, dim=0)
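
# Illustrative encoder usage (assumed shapes; supports_seq supplies two
# supports for each of the T input steps):
#
#   enc = PDG2SeqEncoder(node_num=20, input_dim=1, hidden_dim=64,
#                        cheb_k=2, embed_dim=10, time_dim=16, num_layers=1)
#   x_seq = torch.randn(4, 12, 20, 1)
#   h0 = torch.zeros(1, 4, 20, 64)
#   supports_seq = [[torch.softmax(torch.randn(20, 20), dim=1) for _ in range(2)]
#                   for _ in range(12)]
#   out, states = enc(x_seq, h0, supports_seq)                  # out: [4, 12, 20, 64]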


class PDG2SeqDecoder(nn.Module):
    def __init__(
        self,
        node_num: int,
        input_dim: int,
        hidden_dim: int,
        cheb_k: int,
        embed_dim: int,
        time_dim: int,
        num_layers: int = 1,
    ):
        super().__init__()
        self.cells = nn.ModuleList([
            PDG2SeqCell(node_num, input_dim if i == 0 else hidden_dim,
                        hidden_dim, cheb_k, embed_dim, time_dim)
            for i in range(num_layers)
        ])

    def forward(
        self,
        go: torch.Tensor,      # [B, N, input_dim]
        h_prev: torch.Tensor,  # [num_layers, B, N, hidden]
        supports: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        new_states = []
        inp = go
        for layer, cell in enumerate(self.cells):
            h = cell(inp, h_prev[layer], supports)
            new_states.append(h)
            inp = h
        return inp, torch.stack(new_states, dim=0)


class PDG2Seq(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.num_nodes = args['num_nodes']
        self.horizon = args['horizon']
        self.rnn = args['rnn_units']
        self.output_dim = args['output_dim']
        self.encoder = PDG2SeqEncoder(
            self.num_nodes, args['input_dim'], args['rnn_units'],
            args['cheb_k'], args['embed_dim'], args['time_dim'], args['num_layers']
        )
        self.decoder = PDG2SeqDecoder(
            self.num_nodes, args['output_dim'], args['rnn_units'],
            args['cheb_k'], args['embed_dim'], args['time_dim'], args['num_layers']
        )
        self.proj = nn.Linear(args['rnn_units'], args['output_dim'])
        # Embeddings for time-of-day (288 five-minute slots) and day-of-week.
        self.time_emb = nn.Embedding(288, args['time_dim'])
        self.week_emb = nn.Embedding(7, args['time_dim'])
        # Scheduled-sampling decay constant (reuses the lr_decay_step setting).
        self.cl_decay = args['lr_decay_step']

    def forward(
        self,
        source: torch.Tensor,  # [B, T1, N, D]
        target: Optional[torch.Tensor] = None,
        batches_seen: int = 0,
        supports_seq: Optional[List[List[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        # Time indices from the last two input features: time-of-day normalized
        # to [0, 1] over 288 slots, and day-of-week in {0, ..., 6}. The matching
        # embeddings are available for building dynamic supports but are not
        # combined into the model here yet.
        t_idx = (source[..., -2] * 287).long()
        w_idx = source[..., -1].long()
        if supports_seq is None:
            # ChebConv's weight pool is sized for exactly two supports per step,
            # so the dynamic supports must be supplied by the caller.
            raise ValueError("supports_seq must provide two [N, N] supports per input step")
        B = source.size(0)
        # Initial hidden states for all encoder layers.
        h0 = torch.zeros(
            len(self.encoder.cells), B, self.num_nodes, self.rnn,
            device=source.device
        )
        enc_out, h_prev = self.encoder(source[..., :1], h0, supports_seq)
        # Decode from a zero GO symbol in the output space, seeded with the
        # encoder's final states.
        go = torch.zeros(B, self.num_nodes, self.output_dim, device=source.device)
        outputs = []
        for t in range(self.horizon):
            # Reuse the supports of the last encoder step for every decoder step.
            step_supports = supports_seq[-1]
            go, h_prev = self.decoder(go, h_prev, step_supports)
            pred = self.proj(go)
            outputs.append(pred)
            # Feed the prediction back as the next decoder input; during
            # training, scheduled sampling swaps in the ground truth with a
            # probability that decays as batches_seen grows.
            go = pred
            if self.training and target is not None:
                thresh = self.cl_decay / (self.cl_decay + math.exp(batches_seen / self.cl_decay))
                if torch.rand(1).item() < thresh:
                    go = target[:, t, :, :self.output_dim]
        return torch.stack(outputs, dim=1)
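

if __name__ == "__main__":
    # Minimal smoke test, not part of the original file. The args keys below are
    # assumptions about the expected hyperparameters, and the random
    # row-normalized supports stand in for the model's dynamic graphs.
    args = {
        'num_nodes': 20, 'input_dim': 1, 'output_dim': 1, 'rnn_units': 64,
        'cheb_k': 2, 'embed_dim': 10, 'time_dim': 16, 'num_layers': 1,
        'horizon': 12, 'lr_decay_step': 2000,
    }
    model = PDG2Seq(args)
    B, T, N = 4, 12, args['num_nodes']
    # Input features per node: [value, time-of-day in [0, 1], day-of-week index].
    source = torch.rand(B, T, N, 3)
    source[..., -1] = torch.randint(0, 7, (B, T, N)).float()
    supports_seq = [[torch.softmax(torch.randn(N, N), dim=1) for _ in range(2)]
                    for _ in range(T)]
    out = model(source, supports_seq=supports_seq)
    print(out.shape)  # expected: torch.Size([4, 12, 20, 1])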