155 lines
5.2 KiB
Python
Executable File
155 lines
5.2 KiB
Python
Executable File
import torch
|
|
import torch.nn.functional as F
|
|
import torch.nn as nn
|
|
|
|
import model.STGNRDE.torchcde as torchcde
|
|
from model.STGNRDE.vector_fields import *
|
|
|
|
|
|
class NeuralGCDE(nn.Module):
|
|
def __init__(
|
|
self,
|
|
args,
|
|
func_f,
|
|
func_g,
|
|
input_channels,
|
|
hidden_channels,
|
|
output_channels,
|
|
initial,
|
|
device,
|
|
atol,
|
|
rtol,
|
|
solver,
|
|
):
|
|
super(NeuralGCDE, self).__init__()
|
|
self.num_node = args["num_nodes"]
|
|
self.input_dim = input_channels
|
|
self.hidden_dim = hidden_channels
|
|
self.output_dim = output_channels
|
|
self.horizon = args["horizon"]
|
|
self.num_layers = args["num_layers"]
|
|
|
|
# defaults for NRDE runtime options
|
|
self.model_type = args.get("model_type", "rde") # {'rde', 'rde2'}
|
|
self.emb_opt = args.get("emb_opt", False)
|
|
self.interpolation = args.get("interpolation", "cubic") # {'cubic','linear'}
|
|
self.adp_opt = args.get("adp_opt", False)
|
|
|
|
self.default_graph = args["default_graph"]
|
|
self.node_embeddings = nn.Parameter(
|
|
torch.randn(self.num_node, args["embed_dim"]), requires_grad=True
|
|
)
|
|
|
|
self.func_f = func_f
|
|
self.func_g = func_g
|
|
self.solver = solver
|
|
self.atol = atol
|
|
self.rtol = rtol
|
|
|
|
# predictor
|
|
self.end_conv = nn.Conv2d(
|
|
1,
|
|
args["horizon"] * self.output_dim,
|
|
kernel_size=(1, self.hidden_dim),
|
|
bias=True,
|
|
)
|
|
|
|
self.init_type = "fc"
|
|
if self.init_type == "fc":
|
|
self.initial_h = torch.nn.Linear(self.input_dim, self.hidden_dim)
|
|
self.initial_z = torch.nn.Linear(self.input_dim, self.hidden_dim)
|
|
elif self.init_type == "conv":
|
|
self.start_conv_h = nn.Conv2d(
|
|
in_channels=input_channels,
|
|
out_channels=hidden_channels,
|
|
kernel_size=(1, 1),
|
|
)
|
|
self.start_conv_z = nn.Conv2d(
|
|
in_channels=input_channels,
|
|
out_channels=hidden_channels,
|
|
kernel_size=(1, 1),
|
|
)
|
|
|
|
# optional projection for adaptive pooling path
|
|
if self.adp_opt:
|
|
self.proj = nn.Linear(self.hidden_dim, 1)
|
|
|
|
def red_emb(self, coeffs):
|
|
# no-op reduction by default
|
|
return coeffs
|
|
|
|
def forward(self, times, coeffs):
|
|
# source: B, T_1, N, D
|
|
# target: B, T_2, N, D
|
|
# times = torch.linspace(0, len(times)-1, coeffs.size(-2)).to(coeffs.device)
|
|
# times = torch.linspace(0, coeffs.size(-1), coeffs.size(-2)).to(coeffs.device)
|
|
# times = torch.linspace(0, coeffs.size(-2)-1, coeffs.size(-2)).to(coeffs.device)
|
|
# import pdb; pdb.set_trace()
|
|
if self.emb_opt == True:
|
|
coeffs = self.red_emb(coeffs)
|
|
# Adapt coeffs from tuple/list (a, b, two_c, three_d) to concatenated tensor if necessary
|
|
if isinstance(coeffs, (list, tuple)):
|
|
coeffs = torch.cat(coeffs, dim=-1)
|
|
if self.interpolation == "cubic":
|
|
X = torchcde.CubicSpline(coeffs)
|
|
elif self.interpolation == "linear":
|
|
X = torchcde.LinearInterpolation(coeffs)
|
|
X0 = X.evaluate(X.interval[0])
|
|
|
|
if self.init_type == "fc":
|
|
h0 = self.initial_h(X0)
|
|
z0 = self.initial_z(X0)
|
|
# z0 = self.initial_z(h0)
|
|
elif self.init_type == "conv":
|
|
h0 = (
|
|
self.start_conv_h(X0.transpose(1, 2).unsqueeze(-1))
|
|
.transpose(1, 2)
|
|
.squeeze()
|
|
)
|
|
z0 = (
|
|
self.start_conv_z(X0.transpose(1, 2).unsqueeze(-1))
|
|
.transpose(1, 2)
|
|
.squeeze()
|
|
)
|
|
|
|
if self.model_type == "rde":
|
|
z_T = torchcde.cdeint(
|
|
X=X, func=self.func_g, z0=z0, t=times, adjoint=True, method=self.solver
|
|
)
|
|
elif self.model_type == "rde2":
|
|
step_size = (X.grid_points[1:] - X.grid_points[:-1]).min()
|
|
# adjoint_params = tuple(self.func_f.parameters()) + tuple(self.func_g.parameters()) + (coeffs,)
|
|
|
|
z_T = torchcde.cdeint_custom(
|
|
X=X,
|
|
func_f=self.func_f,
|
|
func_g=self.func_g,
|
|
h0=h0,
|
|
z0=z0,
|
|
t=times,
|
|
adjoint=True,
|
|
method=self.solver,
|
|
)
|
|
|
|
if self.adp_opt == False:
|
|
z_T = z_T[:, :, -1:, :].transpose(1, 2)
|
|
else:
|
|
z_T = z_T.transpose(1, 2)
|
|
retain_score = self.proj(z_T)
|
|
retain_score = retain_score.squeeze()
|
|
retain_score = torch.sigmoid(retain_score.transpose(-1, -2))
|
|
retain_score = retain_score.unsqueeze(-1)
|
|
z_T = torch.matmul(
|
|
retain_score.transpose(-1, -2), z_T.permute(0, 2, 1, 3)
|
|
).transpose(1, 2)
|
|
|
|
# CNN based predictor
|
|
# output = self.end_conv(z_T.unsqueeze(-1).shape) #B, T*C, N, 1
|
|
output = self.end_conv(z_T) # B, T*C, N, 1
|
|
output = output.squeeze(-1).reshape(
|
|
-1, self.horizon, self.output_dim, self.num_node
|
|
)
|
|
output = output.permute(0, 1, 3, 2) # B, T, N, C
|
|
|
|
return output
|