import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn


class HyperMLP(nn.Module):
    """General hypernetwork MLP with configurable hidden dims."""

    def __init__(
        self, in_dim: int, out_dim: int, hidden_dims: List[int], activation=nn.Sigmoid
    ):
        super().__init__()
        layers = []
        dims = [in_dim] + hidden_dims + [out_dim]
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            if i < len(dims) - 2:
                layers.append(activation())
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


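# Illustrative usage (toy sizes, assumed for this sketch only): map a 4-dim
# summary vector to 8 mixing coefficients through one 16-unit hidden layer.
#   mlp = HyperMLP(in_dim=4, out_dim=8, hidden_dims=[16])
#   mlp(torch.randn(32, 4)).shape  # torch.Size([32, 8])

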
class ChebConv(nn.Module):
    """Chebyshev graph convolution with a hypernetwork-mixed weight pool."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        cheb_k: int,
        embed_dim: int,
        hidden_dims: List[int],
    ):
        super().__init__()
        # Weight pool: [embed_dim, cheb_k * 2 + 1, in_channels, out_channels].
        # The second axis holds the identity term plus cheb_k diffusion steps
        # for each of the two supports passed to forward().
        self.weight_pool = nn.Parameter(
            torch.Tensor(embed_dim, cheb_k * 2 + 1, in_channels, out_channels)
        )
        self.bias_pool = nn.Parameter(torch.Tensor(embed_dim, out_channels))
        # Hypernetwork producing the dynamic mixing coefficients from the
        # node-averaged input features of shape [B, in_channels].
        self.hyper = HyperMLP(
            in_dim=in_channels,
            out_dim=embed_dim,
            hidden_dims=hidden_dims,
        )
        self.k = cheb_k
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight_pool)
        nn.init.zeros_(self.bias_pool)

    def forward(
        self,
        x: torch.Tensor,  # [B, N, in_channels]
        supports: List[torch.Tensor],  # list of [N, N] adjacency matrices
    ) -> torch.Tensor:
        B, N, C_in = x.shape
        # Compute hyper coefficients per batch element.
        agg = x.mean(dim=1)  # [B, in_channels]
        coeff = self.hyper(agg)  # [B, embed_dim]

        # Build dynamic weights and bias.
        # weight_pool: [E, K, in, out]; coeff: [B, E] -> W: [B, K, in, out]
        W = torch.einsum("be,ekio->bkio", coeff, self.weight_pool)
        b = torch.einsum("be,eo->bo", coeff, self.bias_pool)  # [B, out]

        # Gather the identity term plus cheb_k diffusion steps per support;
        # with two supports this yields K = 2 * cheb_k + 1 terms, matching
        # the K axis of the weight pool.
        x0 = x  # [B, N, in_channels]
        x_list = [x0]
        for support in supports:
            x1 = torch.einsum("ij,bjk->bik", support, x0)
            x_list.append(x1)
            xk = x1
            for _ in range(2, self.k + 1):
                xk = torch.einsum("ij,bjk->bik", support, xk)
                x_list.append(xk)
        # Stack to [B, N, K, in].
        x_stack = torch.stack(x_list, dim=2)
        # Apply the per-batch weights and bias.
        out = torch.einsum("bnki,bkio->bno", x_stack, W) + b.unsqueeze(1)
        return out


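# Illustrative shape check (toy sizes, assumed for this sketch only):
#   conv = ChebConv(in_channels=2, out_channels=4, cheb_k=2, embed_dim=8,
#                   hidden_dims=[16])
#   A = torch.softmax(torch.randn(5, 5), dim=1)
#   out = conv(torch.randn(3, 5, 2), supports=[A, A.t()])
#   out.shape  # torch.Size([3, 5, 4]); K axis = 2 * 2 + 1 = 5

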
class PDG2SeqCell(nn.Module):
    """GRU-style recurrent cell whose gates are ChebConv graph convolutions."""

    def __init__(
        self,
        node_num: int,
        in_dim: int,
        hidden_dim: int,
        cheb_k: int,
        embed_dim: int,
        time_dim: int,
    ):
        super().__init__()
        self.node_num = node_num  # needed by init_hidden()
        self.hidden_dim = hidden_dim
        merge_dim = in_dim + hidden_dim
        # Gates and candidate use ChebConv with hidden_dims=[time_dim].
        self.conv_gate = ChebConv(
            merge_dim, 2 * hidden_dim, cheb_k, embed_dim, hidden_dims=[time_dim]
        )
        self.conv_cand = ChebConv(
            merge_dim, hidden_dim, cheb_k, embed_dim, hidden_dims=[time_dim]
        )

    def forward(
        self,
        x: torch.Tensor,  # [B, N, in_dim]
        h_prev: torch.Tensor,  # [B, N, hidden_dim]
        supports: List[torch.Tensor],  # dynamic supports
    ) -> torch.Tensor:
        merged = torch.cat([x, h_prev], dim=-1)
        gates = self.conv_gate(merged, supports)
        z, r = gates.chunk(2, dim=-1)
        z, r = torch.sigmoid(z), torch.sigmoid(r)
        # AGCRN-style update: z rescales the previous state inside the
        # candidate; r interpolates between the previous state and the
        # candidate.
        merged_cand = torch.cat([x, z * h_prev], dim=-1)
        h_tilde = torch.tanh(self.conv_cand(merged_cand, supports))
        h = r * h_prev + (1 - r) * h_tilde
        return h

    def init_hidden(self, batch_size: int) -> torch.Tensor:
        return torch.zeros(
            batch_size,
            self.node_num,
            self.hidden_dim,
            device=next(self.parameters()).device,
        )


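# Sketch of a single recurrent step (toy sizes, assumed for illustration):
#   cell = PDG2SeqCell(node_num=5, in_dim=1, hidden_dim=8, cheb_k=2,
#                      embed_dim=4, time_dim=16)
#   h = cell.init_hidden(batch_size=3)  # zeros, [3, 5, 8]
#   A = torch.softmax(torch.randn(5, 5), dim=1)
#   h = cell(torch.randn(3, 5, 1), h, supports=[A, A.t()])  # [3, 5, 8]

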
class PDG2SeqEncoder(nn.Module):
    def __init__(
        self,
        node_num: int,
        input_dim: int,
        hidden_dim: int,
        cheb_k: int,
        embed_dim: int,
        time_dim: int,
        num_layers: int = 1,
    ):
        super().__init__()
        self.cells = nn.ModuleList(
            [
                PDG2SeqCell(
                    node_num,
                    input_dim if i == 0 else hidden_dim,
                    hidden_dim,
                    cheb_k,
                    embed_dim,
                    time_dim,
                )
                for i in range(num_layers)
            ]
        )

    def forward(
        self,
        x_seq: torch.Tensor,  # [B, T, N, in_dim]
        h0: torch.Tensor,  # [num_layers, B, N, hidden]
        supports_seq: List[List[torch.Tensor]],  # per-step dynamic supports
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        T = x_seq.size(1)
        states = []
        inputs = x_seq
        # Each layer consumes the full sequence emitted by the layer below.
        for layer, cell in enumerate(self.cells):
            h = h0[layer]
            outs = []
            for t in range(T):
                h = cell(inputs[:, t], h, supports_seq[t])
                outs.append(h)
            inputs = torch.stack(outs, dim=1)
            states.append(h)
        # inputs: top-layer outputs [B, T, N, hidden];
        # states: final hidden state per layer [num_layers, B, N, hidden].
        return inputs, torch.stack(states, dim=0)


class PDG2SeqDecoder(nn.Module):
    def __init__(
        self,
        node_num: int,
        input_dim: int,
        hidden_dim: int,
        cheb_k: int,
        embed_dim: int,
        time_dim: int,
        num_layers: int = 1,
    ):
        super().__init__()
        self.cells = nn.ModuleList(
            [
                PDG2SeqCell(
                    node_num,
                    input_dim if i == 0 else hidden_dim,
                    hidden_dim,
                    cheb_k,
                    embed_dim,
                    time_dim,
                )
                for i in range(num_layers)
            ]
        )

    def forward(
        self,
        go: torch.Tensor,  # [B, N, input_dim]
        h_prev: torch.Tensor,  # [num_layers, B, N, hidden]
        supports: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # One decoding step: run the input up the cell stack and return the
        # top output together with the updated per-layer states.
        new_states = []
        inp = go
        for layer, cell in enumerate(self.cells):
            h = cell(inp, h_prev[layer], supports)
            new_states.append(h)
            inp = h
        return inp, torch.stack(new_states, dim=0)


class PDG2Seq(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.num_nodes = args["num_nodes"]
        self.horizon = args["horizon"]
        self.rnn = args["rnn_units"]
        self.output_dim = args["output_dim"]
        self.encoder = PDG2SeqEncoder(
            self.num_nodes,
            args["input_dim"],
            args["rnn_units"],
            args["cheb_k"],
            args["embed_dim"],
            args["time_dim"],
            args["num_layers"],
        )
        self.decoder = PDG2SeqDecoder(
            self.num_nodes,
            args["output_dim"],
            args["rnn_units"],
            args["cheb_k"],
            args["embed_dim"],
            args["time_dim"],
            args["num_layers"],
        )
        self.proj = nn.Linear(args["rnn_units"], args["output_dim"])
        # 288 five-minute slots per day, 7 days per week.
        self.time_emb = nn.Embedding(288, args["time_dim"])
        self.week_emb = nn.Embedding(7, args["time_dim"])
        # Decay constant for scheduled-sampling curriculum learning.
        self.cl_decay = args["lr_decay_step"]

    def forward(
        self,
        source: torch.Tensor,  # [B, T1, N, D]
        target: Optional[torch.Tensor] = None,
        batches_seen: int = 0,
        supports_seq: Optional[List[List[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        # Time-of-day / day-of-week indices for time_emb and week_emb; the
        # supports are currently supplied externally, so these stay unused
        # until the embedding combination is extended.
        t_idx = (source[..., -2] * 287).long()
        w_idx = source[..., -1].long()
        supports_seq = supports_seq or [[] for _ in range(source.size(1))]
        B = source.size(0)
        # Initial hidden state for every encoder layer.
        h0 = torch.zeros(
            len(self.encoder.cells), B, self.num_nodes, self.rnn, device=source.device
        )
        # Encode only the traffic channel.
        _, enc_states = self.encoder(source[..., :1], h0, supports_seq)
        # GO symbol: zeros with output_dim channels, matching the decoder's
        # first-layer input width.
        go = torch.zeros(B, self.num_nodes, self.output_dim, device=source.device)
        h_prev = enc_states
        outputs = []
        for t in range(self.horizon):
            # Reuse the last observed supports for every decoding step.
            step_supports = supports_seq[-1] if supports_seq else []
            out, h_prev = self.decoder(go, h_prev, step_supports)
            pred = self.proj(out)
            outputs.append(pred)
            # Feed the prediction back by default; during training, replace
            # it with the ground truth with probability `thresh`
            # (inverse-sigmoid decay in batches_seen).
            go = pred
            if self.training and target is not None:
                thresh = self.cl_decay / (
                    self.cl_decay + math.exp(batches_seen / self.cl_decay)
                )
                if torch.rand(1).item() < thresh:
                    go = target[:, t, :, :1]
        return torch.stack(outputs, dim=1)


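if __name__ == "__main__":
    # Minimal smoke test under an assumed toy configuration; the key names
    # mirror the ones read in PDG2Seq.__init__, and the sizes are arbitrary.
    args = {
        "num_nodes": 5,
        "horizon": 3,
        "rnn_units": 8,
        "input_dim": 1,
        "output_dim": 1,
        "cheb_k": 2,
        "embed_dim": 4,
        "time_dim": 16,
        "num_layers": 2,
        "lr_decay_step": 2000,
    }
    model = PDG2Seq(args)
    B, T, N = 2, 6, args["num_nodes"]
    # Channels: traffic value, time-of-day in [0, 1], day-of-week index.
    source = torch.rand(B, T, N, 3)
    A = torch.softmax(torch.randn(N, N), dim=1)
    supports_seq = [[A, A.t()] for _ in range(T)]
    out = model(source, supports_seq=supports_seq)
    print(out.shape)  # torch.Size([2, 3, 5, 1])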