import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

# Whether to use the adjoint method or not.
adjoint = False
if adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint


# Define the ODE function.
# Input:
# --- t: The current time. Accepted because the solver passes it; not used.
# --- x: A 4-D tensor, the value of the state at t. The einsum subscripts in
#        forward() fix the layout: axis 1 is contracted with adj's columns,
#        axis 2 with a [temporal_dim, temporal_dim] matrix and axis 3 with a
#        [feature_dim, feature_dim] matrix.
# Output:
# --- dx/dt: A tensor meaning the derivative of x at t.
class ODEFunc(nn.Module):
    """Right-hand side of the ODE integrated by ``ODEblock``.

    For a state ``x`` the derivative is::

        sigmoid(alpha) / 2 * (adj . x) - x + (x . w) - x + (x . w2) - x + x0

    where ``adj . x`` mixes axis 1 of ``x``, ``x . w`` mixes the last axis,
    ``x . w2`` mixes axis 2, and ``x0`` is the tensor stored through
    ``ODEblock.set_x0``.

    Args:
        feature_dim: Size of the last axis of the state; ``w`` is
            ``[feature_dim, feature_dim]``.
        temporal_dim: Size of axis 2 of the state; ``w2`` is
            ``[temporal_dim, temporal_dim]``.
        adj: 2-D matrix contracted with axis 1 of the state. It is kept as a
            plain attribute (not a parameter), so it is not trained here.
            NOTE(review): the broadcast in ``forward`` presumably expects a
            square matrix -- confirm against callers.
    """

    def __init__(self, feature_dim, temporal_dim, adj):
        super().__init__()
        self.adj = adj
        # Must be assigned (see ODEblock.set_x0) before forward() is called.
        self.x0 = None
        # One learnable gate per column of adj; squashed by sigmoid in forward().
        self.alpha = nn.Parameter(0.8 * torch.ones(adj.shape[1]))
        # Not read anywhere in this file; kept for external users.
        self.beta = 0.6
        # Feature-axis mixing matrix is built as w @ diag(d) @ w.T in forward().
        self.w = nn.Parameter(torch.eye(feature_dim))
        self.d = nn.Parameter(torch.zeros(feature_dim) + 1)
        # Same construction for axis 2 (temporal axis).
        self.w2 = nn.Parameter(torch.eye(temporal_dim))
        self.d2 = nn.Parameter(torch.zeros(temporal_dim) + 1)

    def forward(self, t, x):
        """Return dx/dt for state ``x``; ``t`` is ignored.

        Raises:
            RuntimeError: If ``x0`` has not been set yet.
        """
        if self.x0 is None:
            # Fail with an actionable message instead of an opaque
            # "Tensor + NoneType" TypeError further down.
            raise RuntimeError(
                "ODEFunc.x0 is not set; call ODEblock.set_x0(x) before "
                "integrating."
            )

        # Shape [1, adj.shape[1], 1, 1] so it broadcasts over a 4-D state.
        alpha = torch.sigmoid(self.alpha).unsqueeze(-1).unsqueeze(-1).unsqueeze(0)
        xa = torch.einsum('ij, kjlm->kilm', self.adj, x)

        # Ensure the eigenvalues are at most 1 by clamping the diagonal factor.
        # NOTE(review): this bound is exact only while self.w stays
        # orthogonal; it starts as the identity but is trained unconstrained.
        d = torch.clamp(self.d, min=0, max=1)
        w = torch.mm(self.w * d, torch.t(self.w))
        xw = torch.einsum('ijkl, lm->ijkm', x, w)

        d2 = torch.clamp(self.d2, min=0, max=1)
        w2 = torch.mm(self.w2 * d2, torch.t(self.w2))
        xw2 = torch.einsum('ijkl, km->ijml', x, w2)

        f = alpha / 2 * xa - x + xw - x + xw2 - x + self.x0
        return f


class ODEblock(nn.Module):
    """Integrate an ODE function over a two-point time grid.

    Args:
        odefunc: Module called by the solver as ``odefunc(t, x)``. It must
            accept an ``x0`` attribute (see ``set_x0``).
        t: Time points handed to ``odeint``. ``None`` (the default) means
            ``torch.tensor([0, 1])``. Only the solution at ``t[1]`` is
            returned by ``forward``.
    """

    def __init__(self, odefunc, t=None):
        super().__init__()
        if t is None:
            # Built per instance rather than as a default argument, so
            # instances never share one tensor created at import time.
            t = torch.tensor([0, 1])
        self.t = t
        self.odefunc = odefunc

    def set_x0(self, x0):
        """Store a detached copy of ``x0`` on the wrapped ODE function."""
        self.odefunc.x0 = x0.clone().detach()

    def forward(self, x):
        """Return the ODE solution at ``t[1]`` starting from ``x`` at ``t[0]``."""
        # Cast the time grid to x's tensor type before handing it to the solver.
        t = self.t.type_as(x)
        # odeint returns one solution per entry of t; index 1 is the end point.
        # NOTE(review): no step-size option is passed, so the number of Euler
        # steps is torchdiffeq's default for this grid -- confirm it is intended.
        z = odeint(self.odefunc, x, t, method='euler')[1]
        return z


# Define the ODEGCN model.
class ODEG(nn.Module):
    """ODE layer followed by a ReLU.

    Wraps an ``ODEblock`` around an ``ODEFunc`` and integrates the input
    from time 0 to ``time``.

    Args:
        feature_dim: Forwarded to ``ODEFunc``.
        temporal_dim: Forwarded to ``ODEFunc``.
        adj: Forwarded to ``ODEFunc``.
        time: End of the integration interval; the time grid is ``[0, time]``.
    """

    def __init__(self, feature_dim, temporal_dim, adj, time):
        super().__init__()
        ode_func = ODEFunc(feature_dim, temporal_dim, adj)
        self.odeblock = ODEblock(ode_func, t=torch.tensor([0, time]))

    def forward(self, x):
        """Integrate ``x`` through the ODE block and apply ReLU to the result."""
        # The input is also stored as x0, which the ODE function adds to dx/dt.
        self.odeblock.set_x0(x)
        z = self.odeblock(x)
        return F.relu(z)