import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple


class HyperMLP(nn.Module):
    """General hypernetwork MLP with configurable hidden widths.

    Maps a feature vector to mixing coefficients; an activation is inserted
    after every layer except the last.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        hidden_dims: widths of the intermediate layers (may be empty).
        activation: activation module class applied between layers.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_dims: List[int],
                 activation=nn.Sigmoid):
        super().__init__()
        layers = []
        dims = [in_dim] + hidden_dims + [out_dim]
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            if i < len(dims) - 2:  # no activation after the output layer
                layers.append(activation())
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class ChebConv(nn.Module):
    """Chebyshev-style graph convolution with hypernetwork-mixed weights.

    A weight pool of shape [embed_dim, 2*cheb_k + 1, in, out] is combined
    into per-batch weights by coefficients predicted from the batch input.

    NOTE(review): the diffusion stack produces 1 + cheb_k * len(supports)
    terms, so the pool's K dimension (2*cheb_k + 1) matches only when
    exactly two supports are passed to forward — verify callers.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        cheb_k: int,
        embed_dim: int,
        hidden_dims: List[int],
    ):
        super().__init__()
        # weight pool: [embed_dim, cheb_k*2+1, in_channels, out_channels]
        self.weight_pool = nn.Parameter(
            torch.Tensor(embed_dim, cheb_k * 2 + 1, in_channels, out_channels)
        )
        self.bias_pool = nn.Parameter(torch.Tensor(embed_dim, out_channels))
        # BUG FIX: forward feeds the hypernetwork x.mean(dim=1), which has
        # in_channels features; the original in_dim of
        # in_channels + out_channels was a guaranteed shape mismatch.
        self.hyper = HyperMLP(in_dim=in_channels, out_dim=embed_dim,
                              hidden_dims=hidden_dims)
        self.k = cheb_k
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight_pool)
        nn.init.zeros_(self.bias_pool)

    def forward(
        self,
        x: torch.Tensor,                 # [B, N, in_channels]
        supports: List[torch.Tensor],    # list of [N, N] adjacency matrices
    ) -> torch.Tensor:
        B, N, C_in = x.shape
        # Per-batch mixing coefficients from the node-averaged input.
        agg = x.mean(dim=1)              # [B, in_channels]
        coeff = self.hyper(agg)          # [B, embed_dim]
        # Mix the pools into batch-specific weights and bias.
        W = torch.einsum('be,ekio->bkio', coeff, self.weight_pool)  # [B, K, in, out]
        b = torch.einsum('be,eo->bo', coeff, self.bias_pool)        # [B, out]
        # Build the diffusion stack: the raw signal plus cheb_k repeated
        # diffusions for each support.
        x0 = x                           # [B, N, in_channels]
        x_list = [x0]
        for support in supports:
            x1 = torch.einsum('ij,bjk->bik', support, x0)
            x_list.append(x1)
            xk = x1
            for _ in range(2, self.k + 1):
                xk = torch.einsum('ij,bjk->bik', support, xk)
                x_list.append(xk)
        # stack to [B, N, K, in] and contract with the mixed weights
        x_stack = torch.stack(x_list, dim=2)
        out = torch.einsum('bnki,bkio->bno', x_stack, W) + b.unsqueeze(1)
        return out


class PDG2SeqCell(nn.Module):
    """Graph-convolutional recurrent cell (GRU-style) built on ChebConv."""

    def __init__(
        self,
        node_num: int,
        in_dim: int,
        hidden_dim: int,
        cheb_k: int,
        embed_dim: int,
        time_dim: int,
    ):
        super().__init__()
        # BUG FIX: node_num was never stored, so init_hidden raised
        # AttributeError on self.node_num.
        self.node_num = node_num
        self.hidden_dim = hidden_dim
        merge_dim = in_dim + hidden_dim
        # gates and candidate use ChebConv with hidden_dims=[time_dim]
        self.conv_gate = ChebConv(merge_dim, 2 * hidden_dim, cheb_k,
                                  embed_dim, hidden_dims=[time_dim])
        self.conv_cand = ChebConv(merge_dim, hidden_dim, cheb_k,
                                  embed_dim, hidden_dims=[time_dim])

    def forward(
        self,
        x: torch.Tensor,                 # [B, N, in_dim]
        h_prev: torch.Tensor,            # [B, N, hidden_dim]
        supports: List[torch.Tensor],    # dynamic supports
    ) -> torch.Tensor:
        merged = torch.cat([x, h_prev], dim=-1)
        gates = self.conv_gate(merged, supports)
        z, r = gates.chunk(2, dim=-1)
        z, r = torch.sigmoid(z), torch.sigmoid(r)
        # NOTE(review): z gates h_prev inside the candidate and r blends the
        # output — the update/reset roles are swapped relative to a textbook
        # GRU, but this reproduces the original behavior intentionally.
        merged_cand = torch.cat([x, z * h_prev], dim=-1)
        h_tilde = torch.tanh(self.conv_cand(merged_cand, supports))
        h = r * h_prev + (1 - r) * h_tilde
        return h

    def init_hidden(self, batch_size: int) -> torch.Tensor:
        """Zero hidden state [batch, node_num, hidden_dim] on the cell's device."""
        return torch.zeros(batch_size, self.node_num, self.hidden_dim,
                           device=next(self.parameters()).device)


class PDG2SeqEncoder(nn.Module):
    """Stacked recurrent encoder; each layer consumes the previous layer's
    full output sequence."""

    def __init__(
        self,
        node_num: int,
        input_dim: int,
        hidden_dim: int,
        cheb_k: int,
        embed_dim: int,
        time_dim: int,
        num_layers: int = 1,
    ):
        super().__init__()
        self.cells = nn.ModuleList([
            PDG2SeqCell(node_num, input_dim if i == 0 else hidden_dim,
                        hidden_dim, cheb_k, embed_dim, time_dim)
            for i in range(num_layers)
        ])

    def forward(
        self,
        x_seq: torch.Tensor,                     # [B, T, N, in_dim]
        h0: torch.Tensor,                        # [num_layers, B, N, hidden]
        supports_seq: List[List[torch.Tensor]],  # per-timestep supports
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns (last layer's output sequence [B, T, N, hidden],
        final states per layer [num_layers, B, N, hidden])."""
        B, T, N, _ = x_seq.shape
        states = []
        inputs = x_seq
        for layer, cell in enumerate(self.cells):
            h = h0[layer]
            outs = []
            for t in range(T):
                h = cell(inputs[:, t], h, supports_seq[t])
                outs.append(h)
            inputs = torch.stack(outs, dim=1)
            states.append(h)
        return inputs, torch.stack(states, dim=0)


class PDG2SeqDecoder(nn.Module):
    """Stacked recurrent decoder; advances one step per call."""

    def __init__(
        self,
        node_num: int,
        input_dim: int,
        hidden_dim: int,
        cheb_k: int,
        embed_dim: int,
        time_dim: int,
        num_layers: int = 1,
    ):
        super().__init__()
        self.cells = nn.ModuleList([
            PDG2SeqCell(node_num, input_dim if i == 0 else hidden_dim,
                        hidden_dim, cheb_k, embed_dim, time_dim)
            for i in range(num_layers)
        ])

    def forward(
        self,
        go: torch.Tensor,                # [B, N, input_dim]
        h_prev: torch.Tensor,            # [num_layers, B, N, hidden]
        supports: List[torch.Tensor],    # supports for this step
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns (top-layer output [B, N, hidden], new states per layer)."""
        new_states = []
        inp = go
        for layer, cell in enumerate(self.cells):
            h = cell(inp, h_prev[layer], supports)
            new_states.append(h)
            inp = h
        return inp, torch.stack(new_states, dim=0)


class PDG2Seq(nn.Module):
    """Encoder-decoder sequence model over dynamic graph supports.

    Args (via the ``args`` mapping): num_nodes, horizon, rnn_units,
    input_dim, output_dim, cheb_k, embed_dim, time_dim, num_layers,
    lr_decay_step (scheduled-sampling decay constant).
    """

    def __init__(self, args):
        super().__init__()
        self.num_nodes = args['num_nodes']
        self.horizon = args['horizon']
        self.rnn = args['rnn_units']
        self.input_dim = args['input_dim']
        self.output_dim = args['output_dim']
        self.encoder = PDG2SeqEncoder(
            self.num_nodes, args['input_dim'], args['rnn_units'],
            args['cheb_k'], args['embed_dim'], args['time_dim'],
            args['num_layers']
        )
        self.decoder = PDG2SeqDecoder(
            self.num_nodes, args['output_dim'], args['rnn_units'],
            args['cheb_k'], args['embed_dim'], args['time_dim'],
            args['num_layers']
        )
        self.proj = nn.Linear(args['rnn_units'], args['output_dim'])
        self.time_emb = nn.Embedding(288, args['time_dim'])
        self.week_emb = nn.Embedding(7, args['time_dim'])
        self.cl_decay = args['lr_decay_step']

    def forward(
        self,
        source: torch.Tensor,            # [B, T1, N, D]
        target: torch.Tensor = None,     # [B, horizon, N, >=output_dim]
        batches_seen: int = 0,
        supports_seq: List[List[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Returns predictions [B, horizon, N, output_dim]."""
        # Time-of-day / day-of-week indices. Kept for embedding extensions;
        # currently not consumed by the encoder/decoder.
        t_idx = (source[..., -2] * 287).long()
        w_idx = source[..., -1].long()
        # NOTE(review): empty supports make ChebConv's diffusion stack
        # shorter than its weight pool (2*cheb_k + 1); callers should pass
        # two supports per timestep — verify upstream.
        supports_seq = supports_seq or [[] for _ in range(source.size(1))]
        B = source.size(0)
        # Zero initial hidden state for every encoder layer.
        h0 = torch.zeros(len(self.encoder.cells), B, self.num_nodes,
                         self.rnn, device=source.device)
        enc_out, enc_states = self.encoder(
            source[..., :self.input_dim], h0, supports_seq)
        # BUG FIX: the decoder's first cell expects output_dim features, so
        # the GO symbol is a zero tensor of that width (the original fed the
        # hidden-width encoder output — a runtime shape mismatch).
        go = torch.zeros(B, self.num_nodes, self.output_dim,
                         device=source.device)
        h_prev = enc_states
        outputs = []
        for t in range(self.horizon):
            step_supports = supports_seq[-1] if supports_seq else []
            h_out, h_prev = self.decoder(go, h_prev, step_supports)
            pred = self.proj(h_out)
            outputs.append(pred)
            # BUG FIX: feed the projection back by default; the original
            # never reset `go` in eval mode, leaving it at hidden width.
            go = pred
            if self.training and target is not None:
                # Scheduled sampling: teacher-forcing probability decays
                # with batches_seen.
                thresh = self.cl_decay / (
                    self.cl_decay + math.exp(batches_seen / self.cl_decay))
                if torch.rand(1).item() < thresh:
                    go = target[:, t, :, :self.output_dim]
        return torch.stack(outputs, dim=1)