"""Tanh-bounded vector-field networks.

Every public module in this file maps a hidden state whose last axis has
``hidden_channels`` features to a matrix-valued output squashed by ``tanh``:

* ``FinalTanh_f`` / ``FinalTanh_f_prime`` -- plain MLPs.
* ``FinalTanh_f2`` -- the same idea built from 1x1 ``Conv2d`` layers.
* ``VectorField_g`` / ``VectorField_only_g`` / ``VectorField_g_prime`` --
  linear -> adaptive graph convolution -> linear.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

# Shared by every ``extra_repr`` in this file.
_REPR_TEMPLATE = (
    "input_channels: {}, hidden_channels: {}, hidden_hidden_channels: {}, "
    "num_hidden_layers: {}"
)


def _describe_channels(module):
    """Return the channel summary string used by ``extra_repr``."""
    return _REPR_TEMPLATE.format(
        module.input_channels,
        module.hidden_channels,
        module.hidden_hidden_channels,
        module.num_hidden_layers,
    )


def _pick_state(args):
    """Return the state tensor from ``forward(z)`` or ``forward(t, z)`` calls.

    With exactly one positional argument that argument is the state;
    otherwise the second argument is used (an ``IndexError`` is raised when
    no argument is given).
    """
    return args[0] if len(args) == 1 else args[1]


def _new_agc_pools(embed_dim, cheb_k, channels):
    """Create initialised weight / bias pools for the adaptive graph conv.

    The pools used to be allocated with ``torch.FloatTensor(...)``, i.e.
    uninitialised memory that may contain NaN/Inf.  They are now given a
    well-defined starting value; callers that re-initialise all parameters
    afterwards are unaffected.

    Returns:
        ``(weights_pool, bias_pool)`` with shapes
        ``(embed_dim, cheb_k, channels, channels)`` and
        ``(embed_dim, channels)``.
    """
    weights_pool = nn.Parameter(torch.empty(embed_dim, cheb_k, channels, channels))
    bias_pool = nn.Parameter(torch.empty(embed_dim, channels))
    nn.init.xavier_uniform_(weights_pool)
    nn.init.zeros_(bias_pool)
    return weights_pool, bias_pool


def _adaptive_graph_conv(z, node_embeddings, weights_pool, bias_pool, cheb_k):
    """Apply the adaptive graph convolution shared by the VectorField modules.

    Args:
        z: tensor of shape ``(B, N, dim_in)``.
        node_embeddings: ``(N, embed_dim)`` learnable node embeddings.
        weights_pool: ``(embed_dim, cheb_k, dim_in, dim_out)``.
        bias_pool: ``(embed_dim, dim_out)``.
        cheb_k: number of stacked support matrices the weights expect.

    Returns:
        Tensor of shape ``(B, N, dim_out)``.
    """
    node_num = node_embeddings.shape[0]
    # Data-adaptive graph: row-normalised similarity of the node embeddings.
    supports = F.softmax(
        F.relu(torch.mm(node_embeddings, node_embeddings.transpose(0, 1))),
        dim=1,
    )
    identity = torch.eye(node_num, device=supports.device, dtype=supports.dtype)
    support_set = [identity, supports]
    # Higher orders follow the recurrence T_k = 2 * A @ T_{k-1} - T_{k-2}.
    doubled = 2 * supports
    for _ in range(2, cheb_k):
        support_set.append(torch.matmul(doubled, support_set[-1]) - support_set[-2])
    supports = torch.stack(support_set, dim=0)  # (K, N, N)

    # Node-adaptive parameters: every node gets its own weights and bias.
    weights = torch.einsum("nd,dkio->nkio", node_embeddings, weights_pool)
    bias = torch.matmul(node_embeddings, bias_pool)  # (N, dim_out)

    x_g = torch.einsum("knm,bmc->bknc", supports, z)  # (B, K, N, dim_in)
    x_g = x_g.permute(0, 2, 1, 3)  # (B, N, K, dim_in)
    return torch.einsum("bnki,nkio->bno", x_g, weights) + bias  # (B, N, dim_out)


class _TanhMLPField(nn.Module):
    """MLP mapping ``(..., hidden_channels)`` to a tanh-bounded matrix.

    The output has shape ``(..., hidden_channels, out_columns)``.
    """

    def __init__(
        self,
        input_channels,
        hidden_channels,
        hidden_hidden_channels,
        num_hidden_layers,
        out_columns,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.hidden_hidden_channels = hidden_hidden_channels
        self.num_hidden_layers = num_hidden_layers
        # Width of the trailing output axis; a plain int, not a parameter.
        self._out_columns = out_columns

        self.linear_in = nn.Linear(hidden_channels, hidden_hidden_channels)
        self.linears = nn.ModuleList(
            [
                nn.Linear(hidden_hidden_channels, hidden_hidden_channels)
                for _ in range(num_hidden_layers - 1)
            ]
        )
        self.linear_out = nn.Linear(
            hidden_hidden_channels, hidden_channels * out_columns
        )

    def extra_repr(self):
        """Summarise the channel configuration for ``repr(module)``."""
        return _describe_channels(self)

    def forward(self, *args):
        """Evaluate the field; accepts ``forward(z)`` or ``forward(t, z)``."""
        z = _pick_state(args)
        z = self.linear_in(z).relu()
        for layer in self.linears:
            z = layer(z).relu()
        # Split the flat output features into a (hidden, out_columns) matrix.
        out = self.linear_out(z).view(
            *z.shape[:-1], self.hidden_channels, self._out_columns
        )
        return out.tanh()


class FinalTanh_f(_TanhMLPField):
    """MLP field with output ``(..., hidden_channels, input_channels)``."""

    def __init__(
        self, input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers
    ):
        super().__init__(
            input_channels,
            hidden_channels,
            hidden_hidden_channels,
            num_hidden_layers,
            out_columns=input_channels,
        )


class FinalTanh_f_prime(_TanhMLPField):
    """MLP field with output ``(..., hidden_channels, hidden_channels)``."""

    def __init__(
        self, input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers
    ):
        super().__init__(
            input_channels,
            hidden_channels,
            hidden_hidden_channels,
            num_hidden_layers,
            out_columns=hidden_channels,
        )


class FinalTanh_f2(torch.nn.Module):
    """1x1-convolution variant of ``FinalTanh_f``.

    Takes a state of shape ``(B, N, hidden_channels)`` and returns
    ``(B, N, hidden_channels, input_channels)``.  ``hidden_hidden_channels``
    is stored for ``extra_repr`` only; every layer is ``hidden_channels`` wide.
    """

    def __init__(
        self, input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers
    ):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.hidden_hidden_channels = hidden_hidden_channels
        self.num_hidden_layers = num_hidden_layers

        self.start_conv = torch.nn.Conv2d(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            kernel_size=(1, 1),
        )
        self.linears = torch.nn.ModuleList(
            [
                torch.nn.Conv2d(
                    in_channels=hidden_channels,
                    out_channels=hidden_channels,
                    kernel_size=(1, 1),
                )
                for _ in range(num_hidden_layers - 1)
            ]
        )
        self.linear_out = torch.nn.Conv2d(
            in_channels=hidden_channels,
            out_channels=input_channels * hidden_channels,
            kernel_size=(1, 1),
        )

    def extra_repr(self):
        """Summarise the channel configuration for ``repr(module)``."""
        return _describe_channels(self)

    def forward(self, *args):
        """Evaluate the field; accepts ``forward(z)`` or ``forward(t, z)``."""
        z = _pick_state(args)  # (B, N, H)
        # Channels-first layout with a dummy width axis for Conv2d: (B, H, N, 1).
        z = self.start_conv(z.transpose(1, 2).unsqueeze(-1)).relu()
        for conv in self.linears:
            z = conv(z).relu()
        # Drop ONLY the dummy width axis.  A bare ``squeeze()`` would also
        # remove a batch or node axis of size 1 and break the transpose below.
        out = self.linear_out(z).squeeze(-1).transpose(1, 2)  # (B, N, H * I)
        out = out.reshape(*out.shape[:-1], self.hidden_channels, self.input_channels)
        return out.tanh()


class _GraphVectorField(torch.nn.Module):
    """Shared implementation: linear -> adaptive graph conv -> linear -> tanh.

    Maps ``(B, num_nodes, hidden_channels)`` to
    ``(B, num_nodes, hidden_channels, input_channels)``.  The graph parameters
    are only created when ``g_type == "agc"``; any other value makes the
    forward pass raise ``ValueError``.  ``num_hidden_layers`` is stored for
    ``extra_repr`` only.
    """

    def __init__(
        self,
        input_channels,
        hidden_channels,
        hidden_hidden_channels,
        num_hidden_layers,
        num_nodes,
        cheb_k,
        embed_dim,
        g_type,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.hidden_hidden_channels = hidden_hidden_channels
        self.num_hidden_layers = num_hidden_layers

        self.linear_in = torch.nn.Linear(hidden_channels, hidden_hidden_channels)
        # Project to (hidden_channels, input_channels) for cdeint requirement.
        self.linear_out = torch.nn.Linear(
            hidden_hidden_channels, hidden_channels * input_channels
        )

        self.g_type = g_type
        if self.g_type == "agc":
            self.node_embeddings = nn.Parameter(
                torch.randn(num_nodes, embed_dim), requires_grad=True
            )
            self.cheb_k = cheb_k
            self.weights_pool, self.bias_pool = _new_agc_pools(
                embed_dim, cheb_k, hidden_hidden_channels
            )

    def extra_repr(self):
        """Summarise the channel configuration for ``repr(module)``."""
        return _describe_channels(self)

    def _vector_field(self, z):
        """Run the full pipeline on a state ``z`` of shape ``(B, N, H)``.

        Raises:
            ValueError: if ``g_type`` is not ``"agc"``.
        """
        z = self.linear_in(z).relu()
        if self.g_type != "agc":
            raise ValueError("Check g_type argument")
        z = self.agc(z)
        # Output shape (..., hidden_channels, input_channels).
        out = self.linear_out(z).view(
            *z.shape[:-1], self.hidden_channels, self.input_channels
        )
        return out.tanh()

    def agc(self, z):
        """
        Adaptive Graph Convolution
        - Node Adaptive Parameter Learning
        - Data Adaptive Graph Generation

        ``z`` has shape ``(B, N, hidden_hidden_channels)``; so does the result.
        """
        return _adaptive_graph_conv(
            z, self.node_embeddings, self.weights_pool, self.bias_pool, self.cheb_k
        )


class VectorField_g(_GraphVectorField):
    """Graph vector field callable as ``forward(z)`` or ``forward(t, z)``."""

    def forward(self, *args):
        """Evaluate the field on the state picked from ``args``."""
        return self._vector_field(_pick_state(args))


class VectorField_only_g(_GraphVectorField):
    """Graph vector field callable as ``forward(z)`` or ``forward(t, z)``.

    NOTE: the original author marked this class's output projection with a
    bare ``FIXME`` and left a ``hidden_channels * hidden_channels``
    alternative commented out; the active
    ``input_channels * hidden_channels`` projection is kept.
    """

    def forward(self, *args):
        """Evaluate the field on the state picked from ``args``."""
        return self._vector_field(_pick_state(args))


class VectorField_g_prime(_GraphVectorField):
    """Graph vector field whose ``forward`` takes the state ``z`` only."""

    def forward(self, z):
        """Evaluate the field on ``z`` of shape ``(B, num_nodes, hidden_channels)``."""
        return self._vector_field(z)