120 lines
5.3 KiB
Python
Executable File
120 lines
5.3 KiB
Python
Executable File
import torch
|
|
import torch.nn.functional as F
|
|
import torch.nn as nn
|
|
|
|
import model.STGNRDE.torchcde as torchcde
|
|
from model.STGNRDE.vector_fields import *
|
|
|
|
class NeuralGCDE(nn.Module):
    """Spatio-temporal neural differential-equation forecaster.

    The input path is interpolated (cubic spline or linear), initial hidden
    states are produced from the path's first value, the state is integrated
    with ``torchcde`` and a convolutional head maps the final (or
    score-weighted) state to ``horizon`` future steps for every node.
    """

    def __init__(self, args, func_f, func_g, input_channels, hidden_channels,
                 output_channels, initial, device, atol, rtol, solver):
        """Build the model.

        Args:
            args: Mapping of hyper-parameters. Required keys: ``'num_nodes'``,
                ``'horizon'``, ``'num_layers'``, ``'default_graph'`` and
                ``'embed_dim'``. Optional keys (with defaults):
                ``'model_type'`` (``'rde'``; ``'rde'`` or ``'rde2'``),
                ``'emb_opt'`` (False), ``'interpolation'`` (``'cubic'``;
                ``'cubic'`` or ``'linear'``) and ``'adp_opt'`` (False).
            func_f: Vector field passed as ``func_f`` to
                ``torchcde.cdeint_custom`` (``model_type == 'rde2'`` only).
            func_g: Vector field passed to ``torchcde.cdeint`` (``'rde'``) or
                as ``func_g`` to ``torchcde.cdeint_custom`` (``'rde2'``).
            input_channels: Channel count of the interpolated path.
            hidden_channels: Size of the hidden state.
            output_channels: Channels predicted per node and horizon step.
            initial: Not used by this class; kept for interface compatibility.
            device: Not used by this class; kept for interface compatibility.
            atol: Absolute solver tolerance (stored only).
            rtol: Relative solver tolerance (stored only).
            solver: Method name passed as ``method`` to the solver call.
        """
        super().__init__()
        self.num_node = args['num_nodes']
        self.input_dim = input_channels
        self.hidden_dim = hidden_channels
        self.output_dim = output_channels
        self.horizon = args['horizon']
        self.num_layers = args['num_layers']

        # Runtime options and their defaults.
        self.model_type = args.get('model_type', 'rde')  # {'rde', 'rde2'}
        self.emb_opt = args.get('emb_opt', False)
        self.interpolation = args.get('interpolation', 'cubic')  # {'cubic', 'linear'}
        self.adp_opt = args.get('adp_opt', False)

        self.default_graph = args['default_graph']
        # Learnable per-node embeddings; not read inside this class itself.
        self.node_embeddings = nn.Parameter(
            torch.randn(self.num_node, args['embed_dim']), requires_grad=True)

        self.func_f = func_f
        self.func_g = func_g
        self.solver = solver
        # NOTE(review): atol/rtol are stored but never forwarded to the solver
        # calls in forward(); confirm whether the torchcde calls should get them.
        self.atol = atol
        self.rtol = rtol

        # Predictor: a (1 x hidden_dim) kernel collapses the hidden dimension
        # into horizon * output_dim channels per node.
        # (Submodule creation order is kept stable: it fixes RNG-dependent
        # initialisation and state_dict key order.)
        self.end_conv = nn.Conv2d(1, self.horizon * self.output_dim,
                                  kernel_size=(1, self.hidden_dim), bias=True)

        # How the initial states h0/z0 are produced from the path's first value.
        self.init_type = 'fc'
        if self.init_type == 'fc':
            self.initial_h = nn.Linear(self.input_dim, self.hidden_dim)
            self.initial_z = nn.Linear(self.input_dim, self.hidden_dim)
        elif self.init_type == 'conv':
            self.start_conv_h = nn.Conv2d(in_channels=input_channels,
                                          out_channels=hidden_channels,
                                          kernel_size=(1, 1))
            self.start_conv_z = nn.Conv2d(in_channels=input_channels,
                                          out_channels=hidden_channels,
                                          kernel_size=(1, 1))

        # Scoring layer for the score-weighted temporal pooling (adp_opt only).
        if self.adp_opt:
            self.proj = nn.Linear(self.hidden_dim, 1)

    def red_emb(self, coeffs):
        """Hook to reduce ``coeffs`` when ``emb_opt`` is set; identity by default."""
        return coeffs

    def _initial_states(self, X0):
        """Map the path's first value ``X0`` to the initial states ``(h0, z0)``.

        ``X0`` is presumably ``(B, N, input_dim)`` (the conv path transposes
        axes 1/2 and convolves over ``input_dim`` channels) — confirm.

        Raises:
            ValueError: If ``self.init_type`` is neither ``'fc'`` nor ``'conv'``.
        """
        if self.init_type == 'fc':
            return self.initial_h(X0), self.initial_z(X0)
        if self.init_type == 'conv':
            # (B, N, C) -> (B, C, N, 1) for the 1x1 conv, then back to
            # (B, N, hidden). squeeze(-1) only removes the dummy width axis,
            # so a batch or node axis of size 1 is kept.
            x = X0.transpose(1, 2).unsqueeze(-1)
            h0 = self.start_conv_h(x).squeeze(-1).transpose(1, 2)
            z0 = self.start_conv_z(x).squeeze(-1).transpose(1, 2)
            return h0, z0
        raise ValueError(
            f"Unsupported init_type {self.init_type!r}; expected 'fc' or 'conv'.")

    def _temporal_pool(self, z_T):
        """Reduce the solver output over time to shape ``(B, 1, N, hidden)``.

        ``z_T`` is expected as ``(B, N, T, hidden)``. Without ``adp_opt`` the
        last time step is kept. With it, time steps are summed with weights
        ``sigmoid(self.proj(z))``.
        """
        # Same truthiness test as __init__, so a falsy non-bool adp_opt
        # (e.g. None) never reaches the branch that needs self.proj.
        if not self.adp_opt:
            return z_T[:, :, -1:, :].transpose(1, 2)

        z = z_T.transpose(1, 2)  # (B, T, N, hidden)
        # squeeze(-1), not squeeze(): a bare squeeze would also drop a size-1
        # batch/time/node axis and scramble the axes used below.
        scores = self.proj(z).squeeze(-1)                        # (B, T, N)
        scores = torch.sigmoid(scores.transpose(-1, -2))         # (B, N, T)
        scores = scores.unsqueeze(-1)                            # (B, N, T, 1)
        pooled = torch.matmul(scores.transpose(-1, -2),
                              z.permute(0, 2, 1, 3))             # (B, N, 1, hidden)
        return pooled.transpose(1, 2)                            # (B, 1, N, hidden)

    def forward(self, times, coeffs):
        """Forecast ``horizon`` steps for every node.

        Args:
            times: Time points passed as ``t`` to the solver.
            coeffs: Interpolation coefficients, either a tensor or a
                list/tuple of tensors (e.g. ``(a, b, two_c, three_d)``) that
                is concatenated along the last dimension.

        Returns:
            Tensor of shape ``(B, horizon, num_nodes, output_dim)``.

        Raises:
            ValueError: If ``interpolation`` or ``model_type`` is unsupported.
        """
        # Fail fast on an unsupported configuration instead of hitting an
        # UnboundLocalError on X / z_T further down.
        if self.interpolation not in ('cubic', 'linear'):
            raise ValueError(
                f"Unsupported interpolation {self.interpolation!r}; "
                "expected 'cubic' or 'linear'.")
        if self.model_type not in ('rde', 'rde2'):
            raise ValueError(
                f"Unsupported model_type {self.model_type!r}; "
                "expected 'rde' or 'rde2'.")

        if self.emb_opt:
            coeffs = self.red_emb(coeffs)
        if isinstance(coeffs, (list, tuple)):
            coeffs = torch.cat(coeffs, dim=-1)

        if self.interpolation == 'cubic':
            X = torchcde.CubicSpline(coeffs)
        else:
            X = torchcde.LinearInterpolation(coeffs)
        h0, z0 = self._initial_states(X.evaluate(X.interval[0]))

        if self.model_type == 'rde':
            z_T = torchcde.cdeint(X=X,
                                  func=self.func_g,
                                  z0=z0,
                                  t=times,
                                  adjoint=True,
                                  method=self.solver)
        else:
            z_T = torchcde.cdeint_custom(X=X,
                                         func_f=self.func_f,
                                         func_g=self.func_g,
                                         h0=h0,
                                         z0=z0,
                                         t=times,
                                         adjoint=True,
                                         method=self.solver)

        z_T = self._temporal_pool(z_T)  # (B, 1, N, hidden)

        # CNN-based predictor.
        output = self.end_conv(z_T)  # (B, horizon * output_dim, N, 1)
        output = output.squeeze(-1).reshape(
            -1, self.horizon, self.output_dim, self.num_node)
        return output.permute(0, 1, 3, 2)  # (B, horizon, N, output_dim)