73 lines
3.3 KiB
Python
Executable File
73 lines
3.3 KiB
Python
Executable File
import torch
|
|
import torch.nn.functional as F
|
|
import torch.nn as nn
|
|
from model.STGNCDE import controldiffeq
|
|
from model.STGNCDE.vector_fields import *
|
|
|
|
class NeuralGCDE(nn.Module):
    """Forecasting model built on a graph neural controlled differential equation.

    The input series is supplied as spline coefficients.  The first spline
    value is projected to two initial hidden states (``h0`` and ``z0``), the
    pair is integrated over ``times`` by ``controldiffeq.cdeint_gde_dev``
    using the vector fields ``func_f`` and ``func_g``, and the state at the
    last time point is mapped to the forecast by a convolutional predictor.
    """

    def __init__(self, args, func_f, func_g, input_channels, hidden_channels,
                 output_channels, initial, device, atol, rtol, solver):
        """Build the initial-state projections and the output predictor.

        Args:
            args: Mapping read for the keys ``'num_nodes'``, ``'horizon'``,
                ``'num_layers'``, ``'default_graph'`` and ``'embed_dim'``.
            func_f: Vector field forwarded to ``cdeint_gde_dev`` as ``func_f``.
            func_g: Vector field forwarded to ``cdeint_gde_dev`` as ``func_g``.
            input_channels: Feature size of the spline values.
            hidden_channels: Size of the hidden states ``h`` and ``z``.
            output_channels: Number of features predicted per node and step.
            initial: Not used by this class; kept so callers need not change.
            device: Not used by this class; kept so callers need not change.
            atol: Absolute tolerance forwarded to the solver.
            rtol: Relative tolerance forwarded to the solver.
            solver: Solver name forwarded to ``cdeint_gde_dev`` as ``method``.
        """
        super().__init__()
        self.num_node = args['num_nodes']
        self.input_dim = input_channels
        self.hidden_dim = hidden_channels
        self.output_dim = output_channels
        self.horizon = args['horizon']
        self.num_layers = args['num_layers']

        self.default_graph = args['default_graph']
        # Learnable per-node embedding table of shape (num_nodes, embed_dim).
        self.node_embeddings = nn.Parameter(
            torch.randn(self.num_node, args['embed_dim']), requires_grad=True)

        self.func_f = func_f
        self.func_g = func_g
        self.solver = solver
        self.atol = atol
        self.rtol = rtol

        # Predictor: the (1, hidden_dim) kernel collapses the hidden axis and
        # emits horizon * output_dim channels for every node.
        self.end_conv = nn.Conv2d(1, args['horizon'] * self.output_dim,
                                  kernel_size=(1, self.hidden_dim), bias=True)

        # How the first observation is turned into h0 / z0: 'fc' or 'conv'.
        self.init_type = 'fc'
        if self.init_type == 'fc':
            self.initial_h = torch.nn.Linear(self.input_dim, self.hidden_dim)
            self.initial_z = torch.nn.Linear(self.input_dim, self.hidden_dim)
        elif self.init_type == 'conv':
            self.start_conv_h = nn.Conv2d(in_channels=input_channels,
                                          out_channels=hidden_channels,
                                          kernel_size=(1, 1))
            self.start_conv_z = nn.Conv2d(in_channels=input_channels,
                                          out_channels=hidden_channels,
                                          kernel_size=(1, 1))

    def _initial_states(self, x0):
        """Project the first spline value ``x0`` to the initial ``(h0, z0)`` pair.

        Raises:
            ValueError: If ``self.init_type`` is neither ``'fc'`` nor ``'conv'``.
        """
        if self.init_type == 'fc':
            return self.initial_h(x0), self.initial_z(x0)
        if self.init_type == 'conv':
            # Conv2d expects channels on dim 1, so the feature axis is moved
            # there and a trailing singleton axis is appended.
            # NOTE(review): assumes x0 is (B, N, D) -- confirm against caller.
            x0_conv = x0.transpose(1, 2).unsqueeze(-1)
            # squeeze(-1) removes only the axis appended above.  A bare
            # squeeze() would also drop a batch or node axis of size 1.
            h0 = self.start_conv_h(x0_conv).transpose(1, 2).squeeze(-1)
            z0 = self.start_conv_z(x0_conv).transpose(1, 2).squeeze(-1)
            return h0, z0
        raise ValueError(
            "init_type must be 'fc' or 'conv', got {!r}".format(self.init_type))

    def forward(self, times, coeffs):
        """Integrate the CDE over ``times`` and predict the next ``horizon`` steps.

        Args:
            times: Time points; ``times[0]`` is where the initial state is
                evaluated and the full sequence is handed to the solver.
            coeffs: Coefficients accepted by
                ``controldiffeq.NaturalCubicSpline``.

        Returns:
            Tensor laid out as (B, T, N, C) with ``T = horizon`` and
            ``C = output_dim``.
        """
        # source: B, T_1, N, D
        # target: B, T_2, N, D
        spline = controldiffeq.NaturalCubicSpline(times, coeffs)
        h0, z0 = self._initial_states(spline.evaluate(times[0]))

        z_t = controldiffeq.cdeint_gde_dev(dX_dt=spline.derivative,  # dh_dt
                                           h0=h0,
                                           z0=z0,
                                           func_f=self.func_f,
                                           func_g=self.func_g,
                                           t=times,
                                           method=self.solver,
                                           atol=self.atol,
                                           rtol=self.rtol)

        # Keep only the last entry along dim 0 and move that singleton axis
        # to dim 1, where it serves as the single input channel of end_conv.
        z_T = z_t[-1:, ...].transpose(0, 1)

        # CNN based predictor
        output = self.end_conv(z_T)  # B, T*C, N, 1
        output = output.squeeze(-1).reshape(
            -1, self.horizon, self.output_dim, self.num_node)
        output = output.permute(0, 1, 3, 2)  # B, T, N, C

        return output