# TrafficWheel/model/STGNRDE/GRDE.py
# (Python source file, 120 lines, 5.3 KiB, executable)

import torch
import torch.nn.functional as F
import torch.nn as nn
import model.STGNRDE.torchcde as torchcde
from model.STGNRDE.vector_fields import *
class NeuralGCDE(nn.Module):
    """Graph neural CDE/RDE forecaster built on the project-local ``torchcde``.

    ``forward`` interpolates a control path from spline coefficients, then
    integrates a hidden state along it. It uses ``torchcde.cdeint`` when
    ``model_type == 'rde'`` and ``torchcde.cdeint_custom`` when
    ``model_type == 'rde2'``. A convolutional head decodes the result into a
    tensor of shape ``(batch, horizon, num_nodes, output_dim)``.

    Args:
        args: Mapping of hyper-parameters. Required keys: ``num_nodes``,
            ``horizon``, ``num_layers``, ``default_graph``, ``embed_dim``.
            Optional keys (default in parentheses): ``model_type`` ('rde';
            also 'rde2'), ``emb_opt`` (False), ``interpolation`` ('cubic';
            also 'linear'), ``adp_opt`` (False).
        func_f: Vector field passed to ``cdeint_custom`` as ``func_f``; only
            used when ``model_type == 'rde2'``.
        func_g: Vector field used by both solver paths.
        input_channels: Channel size of the interpolated path.
        hidden_channels: Size of the hidden state.
        output_channels: Number of output features per node and step.
        initial: Unused by this class; kept for interface compatibility.
        device: Unused by this class; kept for interface compatibility.
        atol: Stored as ``self.atol``; not forwarded to the solver here.
        rtol: Stored as ``self.rtol``; not forwarded to the solver here.
        solver: Integration method name, passed to the solver as ``method=``.
    """

    _SUPPORTED_INTERPOLATIONS = ('cubic', 'linear')
    _SUPPORTED_MODEL_TYPES = ('rde', 'rde2')

    def __init__(self, args, func_f, func_g, input_channels, hidden_channels,
                 output_channels, initial, device, atol, rtol, solver):
        super(NeuralGCDE, self).__init__()
        self.num_node = args['num_nodes']
        self.input_dim = input_channels
        self.hidden_dim = hidden_channels
        self.output_dim = output_channels
        self.horizon = args['horizon']
        self.num_layers = args['num_layers']
        # Defaults for NRDE runtime options.
        self.model_type = args.get('model_type', 'rde')  # {'rde', 'rde2'}
        self.emb_opt = args.get('emb_opt', False)
        self.interpolation = args.get('interpolation', 'cubic')  # {'cubic', 'linear'}
        self.adp_opt = args.get('adp_opt', False)
        self.default_graph = args['default_graph']
        # Learnable per-node embedding of shape (num_nodes, embed_dim).
        self.node_embeddings = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
        self.func_f = func_f
        self.func_g = func_g
        self.solver = solver
        self.atol = atol
        self.rtol = rtol
        # Predictor: the (1, hidden_dim) kernel collapses the hidden axis and
        # emits horizon * output_dim channels per node.
        self.end_conv = nn.Conv2d(1, args['horizon'] * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
        # Initial-state encoder. Hard-coded to 'fc'; the 'conv' variant is
        # kept as an alternative.
        self.init_type = 'fc'
        if self.init_type == 'fc':
            self.initial_h = torch.nn.Linear(self.input_dim, self.hidden_dim)
            self.initial_z = torch.nn.Linear(self.input_dim, self.hidden_dim)
        elif self.init_type == 'conv':
            self.start_conv_h = nn.Conv2d(in_channels=input_channels,
                                          out_channels=hidden_channels,
                                          kernel_size=(1, 1))
            self.start_conv_z = nn.Conv2d(in_channels=input_channels,
                                          out_channels=hidden_channels,
                                          kernel_size=(1, 1))
        # Optional projection that scores each time step for the adaptive
        # pooling path in ``forward``.
        if self.adp_opt:
            self.proj = nn.Linear(self.hidden_dim, 1)

    def red_emb(self, coeffs):
        """Hook applied to ``coeffs`` when ``emb_opt`` is enabled.

        This base implementation is a no-op and returns ``coeffs`` unchanged.
        """
        return coeffs

    def forward(self, times, coeffs):
        """Integrate the path described by ``coeffs`` and decode a forecast.

        Args:
            times: Time points passed to the solver as ``t=``.
            coeffs: Interpolation coefficients. Either a tensor, or a
                list/tuple of tensors (e.g. the ``(a, b, two_c, three_d)``
                cubic-spline tuple) that is concatenated on the last dim.

        Returns:
            Tensor of shape ``(batch, horizon, num_nodes, output_dim)``.

        Raises:
            ValueError: if ``interpolation``, ``model_type`` or ``init_type``
                holds an unsupported value. Without this check the failure
                would surface later as an unbound local variable.
        """
        if self.interpolation not in self._SUPPORTED_INTERPOLATIONS:
            raise ValueError(
                f"Unsupported interpolation {self.interpolation!r}; "
                f"expected one of {self._SUPPORTED_INTERPOLATIONS}")
        if self.model_type not in self._SUPPORTED_MODEL_TYPES:
            raise ValueError(
                f"Unsupported model_type {self.model_type!r}; "
                f"expected one of {self._SUPPORTED_MODEL_TYPES}")
        if self.init_type not in ('fc', 'conv'):
            raise ValueError(f"Unsupported init_type {self.init_type!r}; expected 'fc' or 'conv'")

        if self.emb_opt:
            coeffs = self.red_emb(coeffs)
        # Adapt coeffs from tuple/list (a, b, two_c, three_d) to a single
        # concatenated tensor if necessary.
        if isinstance(coeffs, (list, tuple)):
            coeffs = torch.cat(coeffs, dim=-1)

        if self.interpolation == 'cubic':
            X = torchcde.CubicSpline(coeffs)
        else:  # 'linear' (validated above)
            X = torchcde.LinearInterpolation(coeffs)
        X0 = X.evaluate(X.interval[0])

        # h0 is only consumed by the 'rde2' solver path.
        if self.init_type == 'fc':
            h0 = self.initial_h(X0)
            z0 = self.initial_z(X0)
        else:  # 'conv'
            # Channels-first 1x1 conv over a dummy trailing axis. Squeeze only
            # that axis (-1) so a batch or node axis of size 1 is preserved.
            X0_img = X0.transpose(1, 2).unsqueeze(-1)
            h0 = self.start_conv_h(X0_img).transpose(1, 2).squeeze(-1)
            z0 = self.start_conv_z(X0_img).transpose(1, 2).squeeze(-1)

        if self.model_type == 'rde':
            z_T = torchcde.cdeint(X=X,
                                  func=self.func_g,
                                  z0=z0,
                                  t=times,
                                  adjoint=True,
                                  method=self.solver)
        else:  # 'rde2' (validated above)
            # NOTE(review): no step size or atol/rtol is forwarded to the
            # solver. For fixed-step methods consider passing
            # options=dict(step_size=(X.grid_points[1:] - X.grid_points[:-1]).min()).
            # Confirm cdeint_custom accepts it first.
            z_T = torchcde.cdeint_custom(X=X,
                                         func_f=self.func_f,
                                         func_g=self.func_g,
                                         h0=h0,
                                         z0=z0,
                                         t=times,
                                         adjoint=True,
                                         method=self.solver)

        # Shape comments below assume the solver returns (B, N, T, H).
        # TODO confirm against the project-local torchcde.
        if not self.adp_opt:
            # Keep only the last time step: (B, N, 1, H) -> (B, 1, N, H).
            z_T = z_T[:, :, -1:, :].transpose(1, 2)
        else:
            # Adaptive pooling: give every time step a sigmoid score, then
            # take the score-weighted sum over time.
            z_T = z_T.transpose(1, 2)  # (B, T, N, H)
            # proj ends in a size-1 feature axis; squeeze only that axis so
            # batch/time/node axes of size 1 keep their positions.
            retain_score = self.proj(z_T).squeeze(-1)  # (B, T, N)
            retain_score = torch.sigmoid(retain_score.transpose(-1, -2))  # (B, N, T)
            retain_score = retain_score.unsqueeze(-1)  # (B, N, T, 1)
            z_T = torch.matmul(retain_score.transpose(-1, -2), z_T.permute(0, 2, 1, 3)).transpose(1, 2)  # (B, 1, N, H)

        # CNN-based predictor.
        output = self.end_conv(z_T)  # (B, horizon * output_dim, N, 1)
        output = output.squeeze(-1).reshape(-1, self.horizon, self.output_dim, self.num_node)
        output = output.permute(0, 1, 3, 2)  # (B, horizon, N, output_dim)
        return output