# TrafficWheel/model/STGNCDE/GCDE.py

import torch
import torch.nn.functional as F
import torch.nn as nn
from model.STGNCDE import controldiffeq
from model.STGNCDE.vector_fields import *
class NeuralGCDE(nn.Module):
    """Neural graph CDE with a convolutional multi-step prediction head.

    The input path is interpolated with ``controldiffeq.NaturalCubicSpline``.
    Two initial states ``h0``/``z0`` are built from the spline value at
    ``times[0]``. They are integrated by ``controldiffeq.cdeint_gde_dev``
    with the vector fields ``func_f``/``func_g``. The final state is then
    projected by ``end_conv`` onto ``horizon`` future steps.

    Args:
        args: Config mapping. Reads ``"num_nodes"``, ``"horizon"``,
            ``"num_layers"``, ``"default_graph"`` and ``"embed_dim"``.
        func_f: Vector field forwarded to the solver as ``func_f``.
        func_g: Vector field forwarded to the solver as ``func_g``.
        input_channels: Feature size of the interpolated input path.
        hidden_channels: Feature size of the hidden states ``h``/``z``.
        output_channels: Number of features predicted per node and step.
        initial: Accepted for interface compatibility; not used by this class.
        device: Accepted for interface compatibility; not used by this class.
        atol: Absolute tolerance forwarded to the solver.
        rtol: Relative tolerance forwarded to the solver.
        solver: Integration method forwarded to the solver as ``method``.
        init_type: How the initial states are built from the first spline
            value. ``"fc"`` (default, the previously hard-coded behaviour)
            uses linear layers; ``"conv"`` uses 1x1 convolutions.

    Raises:
        ValueError: If ``init_type`` is not ``"fc"`` or ``"conv"``.
    """

    _INIT_TYPES = ("fc", "conv")

    def __init__(
        self,
        args,
        func_f,
        func_g,
        input_channels,
        hidden_channels,
        output_channels,
        initial,
        device,
        atol,
        rtol,
        solver,
        init_type="fc",
    ):
        super().__init__()
        if init_type not in self._INIT_TYPES:
            raise ValueError(
                f"init_type must be one of {self._INIT_TYPES}, got {init_type!r}"
            )
        self.num_node = args["num_nodes"]
        self.input_dim = input_channels
        self.hidden_dim = hidden_channels
        self.output_dim = output_channels
        self.horizon = args["horizon"]
        self.num_layers = args["num_layers"]
        self.default_graph = args["default_graph"]
        # NOTE(review): not referenced by forward() in this class; presumably
        # read elsewhere (e.g. by the vector fields or external code) -- confirm.
        self.node_embeddings = nn.Parameter(
            torch.randn(self.num_node, args["embed_dim"]), requires_grad=True
        )
        self.func_f = func_f
        self.func_g = func_g
        self.solver = solver
        self.atol = atol
        self.rtol = rtol
        # Predictor: one input channel and a (1, hidden_dim) kernel collapse the
        # hidden axis; the output channels hold horizon * output_dim values.
        self.end_conv = nn.Conv2d(
            1,
            args["horizon"] * self.output_dim,
            kernel_size=(1, self.hidden_dim),
            bias=True,
        )
        self.init_type = init_type
        if self.init_type == "fc":
            self.initial_h = torch.nn.Linear(self.input_dim, self.hidden_dim)
            self.initial_z = torch.nn.Linear(self.input_dim, self.hidden_dim)
        else:  # "conv"
            self.start_conv_h = nn.Conv2d(
                in_channels=input_channels,
                out_channels=hidden_channels,
                kernel_size=(1, 1),
            )
            self.start_conv_z = nn.Conv2d(
                in_channels=input_channels,
                out_channels=hidden_channels,
                kernel_size=(1, 1),
            )

    def _initial_state(self, x0):
        """Map the spline value at the first time to ``(h0, z0)``.

        ``x0`` is presumably (B, N, input_dim) -- TODO confirm against
        ``NaturalCubicSpline.evaluate``. Both outputs replace the last axis
        with ``hidden_dim``.
        """
        if self.init_type == "fc":
            return self.initial_h(x0), self.initial_z(x0)
        # Move features onto the channel axis, (B, C, N, 1), for the 1x1 convs.
        x = x0.transpose(1, 2).unsqueeze(-1)
        # squeeze(-1) removes only the width-1 axis added above; a bare
        # squeeze() would also drop batch/node axes of size 1 (e.g. B == 1).
        h0 = self.start_conv_h(x).transpose(1, 2).squeeze(-1)
        z0 = self.start_conv_z(x).transpose(1, 2).squeeze(-1)
        return h0, z0

    def _predict(self, z_t):
        """Project the last solver state onto ``horizon`` future steps.

        ``z_t`` is presumably (T, B, N, hidden_dim): time is taken from dim 0,
        and after swapping dims 0/1 the tensor must fit ``end_conv``'s single
        input channel and (1, hidden_dim) kernel.

        Returns:
            Tensor of shape (B, horizon, num_node, output_dim).
        """
        z_last = z_t[-1:, ...].transpose(0, 1)  # B, 1, N, hidden
        out = self.end_conv(z_last)  # B, horizon * output_dim, N, 1
        out = out.squeeze(-1).reshape(
            -1, self.horizon, self.output_dim, self.num_node
        )
        return out.permute(0, 1, 3, 2)  # B, horizon, N, output_dim

    def forward(self, times, coeffs):
        """Integrate the graph CDE over ``times`` and predict ``horizon`` steps.

        Args:
            times: Observation times passed to the spline and the solver.
            coeffs: Natural cubic spline coefficients of the input path.

        Returns:
            Tensor of shape (B, horizon, num_node, output_dim).
        """
        spline = controldiffeq.NaturalCubicSpline(times, coeffs)
        # Evaluate once; both initial states are built from the same value.
        h0, z0 = self._initial_state(spline.evaluate(times[0]))
        z_t = controldiffeq.cdeint_gde_dev(
            dX_dt=spline.derivative,
            h0=h0,
            z0=z0,
            func_f=self.func_f,
            func_g=self.func_g,
            t=times,
            method=self.solver,
            atol=self.atol,
            rtol=self.rtol,
        )
        return self._predict(z_t)