import torch
import torch.nn as nn
import torch.nn.functional as F


# Shared template for ``extra_repr`` of every vector field in this module.
_SIZES_REPR = ("input_channels: {}, hidden_channels: {}, "
               "hidden_hidden_channels: {}, num_hidden_layers: {}")


def _pick_state(args):
    """Return the state tensor from the positional call arguments.

    A single argument is taken as the state itself; otherwise the second
    argument is used (presumably the solver calls ``f(t, z)`` -- confirm
    against the caller).
    """
    return args[0] if len(args) == 1 else args[1]


class _VectorFieldBase(nn.Module):
    """Stores the four size hyper-parameters and provides ``extra_repr``."""

    def __init__(self, input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.hidden_hidden_channels = hidden_hidden_channels
        self.num_hidden_layers = num_hidden_layers

    def extra_repr(self):
        return _SIZES_REPR.format(self.input_channels, self.hidden_channels,
                                  self.hidden_hidden_channels, self.num_hidden_layers)


class _TanhMLPField(_VectorFieldBase):
    """Linear -> ReLU -> (Linear -> ReLU)* -> Linear -> reshape -> tanh.

    The last axis of the input must have size ``hidden_channels``. The output
    keeps all leading axes and replaces the last one with
    ``(hidden_channels, out_cols)``.
    """

    def __init__(self, input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers,
                 out_cols):
        super().__init__(input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers)
        self._out_cols = out_cols  # width of the trailing output axis
        # Creation order (in, hidden, out) is kept so seeded initialisation is unchanged.
        self.linear_in = nn.Linear(hidden_channels, hidden_hidden_channels)
        self.linears = nn.ModuleList(
            [nn.Linear(hidden_hidden_channels, hidden_hidden_channels)
             for _ in range(num_hidden_layers - 1)]
        )
        self.linear_out = nn.Linear(hidden_hidden_channels, out_cols * hidden_channels)

    def forward(self, *args):
        z = _pick_state(args)
        z = self.linear_in(z).relu()
        for layer in self.linears:
            z = layer(z).relu()
        lead = z.shape[:-1]
        return self.linear_out(z).view(*lead, self.hidden_channels, self._out_cols).tanh()


class FinalTanh_f(_TanhMLPField):
    """MLP vector field with output ``(..., hidden_channels, input_channels)``."""

    def __init__(self, input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers):
        super().__init__(input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers,
                         out_cols=input_channels)


class FinalTanh_f_prime(_TanhMLPField):
    """MLP vector field with square output ``(..., hidden_channels, hidden_channels)``."""

    def __init__(self, input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers):
        super().__init__(input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers,
                         out_cols=hidden_channels)


class FinalTanh_f2(_VectorFieldBase):
    """1x1-convolution variant of :class:`FinalTanh_f`.

    Expects a state of shape ``(batch, nodes, hidden_channels)`` and returns
    ``(batch, nodes, hidden_channels, input_channels)``. Every convolution
    works on ``hidden_channels`` feature maps; ``hidden_hidden_channels`` is
    stored but not used by the layers.
    """

    def __init__(self, input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers):
        super().__init__(input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers)
        self.start_conv = nn.Conv2d(in_channels=hidden_channels,
                                    out_channels=hidden_channels,
                                    kernel_size=(1, 1))
        self.linears = nn.ModuleList(
            [nn.Conv2d(in_channels=hidden_channels,
                       out_channels=hidden_channels,
                       kernel_size=(1, 1))
             for _ in range(num_hidden_layers - 1)]
        )
        self.linear_out = nn.Conv2d(in_channels=hidden_channels,
                                    out_channels=input_channels * hidden_channels,
                                    kernel_size=(1, 1))

    def forward(self, *args):
        z = _pick_state(args)
        # (batch, nodes, hidden) -> (batch, hidden, nodes, 1) so channels are conv features.
        z = self.start_conv(z.transpose(1, 2).unsqueeze(-1)).relu()
        for conv in self.linears:
            z = conv(z).relu()
        # Drop ONLY the trailing dummy axis: a bare ``squeeze()`` would also
        # remove the batch or node axis whenever it has size 1 and break the
        # transpose below.
        out = self.linear_out(z).squeeze(-1).transpose(1, 2)
        out = out.reshape(*out.shape[:-1], self.hidden_channels, self.input_channels)
        return out.tanh()


class _AGCVectorField(_VectorFieldBase):
    """Vector field: Linear -> ReLU -> adaptive graph convolution -> Linear -> tanh.

    Parameters beyond the four size arguments:
        num_nodes: number of graph nodes (rows of the node-embedding matrix).
        cheb_k: number of support matrices used by :meth:`agc`.
        embed_dim: width of each node embedding.
        g_type: graph-convolution flavour; only ``'agc'`` is implemented.
            Any other value builds no graph parameters and makes the forward
            pass raise ``ValueError``.
    """

    def __init__(self, input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers,
                 num_nodes, cheb_k, embed_dim, g_type):
        super().__init__(input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers)
        self.linear_in = nn.Linear(hidden_channels, hidden_hidden_channels)
        # Project to (hidden_channels, input_channels) for the cdeint requirement.
        self.linear_out = nn.Linear(hidden_hidden_channels, hidden_channels * input_channels)

        self.g_type = g_type
        if self.g_type == 'agc':
            self.node_embeddings = nn.Parameter(torch.randn(num_nodes, embed_dim), requires_grad=True)
            self.cheb_k = cheb_k
            # ``torch.empty`` follows the default dtype (matching node_embeddings);
            # the pools are initialised explicitly so the module is usable
            # even when no external re-initialisation pass is run.
            self.weights_pool = nn.Parameter(
                torch.empty(embed_dim, cheb_k, hidden_hidden_channels, hidden_hidden_channels))
            self.bias_pool = nn.Parameter(torch.empty(embed_dim, hidden_hidden_channels))
            nn.init.xavier_uniform_(self.weights_pool)
            nn.init.zeros_(self.bias_pool)

    def _apply_field(self, z):
        """Evaluate the field on a ``(batch, nodes, hidden_channels)`` state."""
        z = self.linear_in(z).relu()
        if self.g_type == 'agc':
            z = self.agc(z)
        else:
            raise ValueError('Check g_type argument')
        # Output shape (..., hidden_channels, input_channels).
        lead = z.shape[:-1]
        return self.linear_out(z).view(*lead, self.hidden_channels, self.input_channels).tanh()

    def forward(self, *args):
        return self._apply_field(_pick_state(args))

    def agc(self, z):
        """
        Adaptive Graph Convolution
        - Node Adaptive Parameter Learning
        - Data Adaptive Graph Generation

        ``z`` has shape ``(batch, nodes, hidden_hidden_channels)``; the result
        has the same shape.
        """
        emb = self.node_embeddings
        node_num = emb.shape[0]
        # Data-adaptive adjacency: row-normalised similarity of node embeddings.
        adj = F.softmax(F.relu(torch.mm(emb, emb.transpose(0, 1))), dim=1)

        # Support stack [I, A, 2*A*T_{k-1} - T_{k-2}, ...]; the identity is
        # created directly on the adjacency's device/dtype.
        support_set = [torch.eye(node_num, device=adj.device, dtype=adj.dtype), adj]
        for _ in range(2, self.cheb_k):
            support_set.append(torch.matmul(2 * adj, support_set[-1]) - support_set[-2])
        supports = torch.stack(support_set, dim=0)

        # Node-specific weights and biases generated from the shared pools.
        weights = torch.einsum('nd,dkio->nkio', emb, self.weights_pool)  # N, cheb_k, dim_in, dim_out
        bias = torch.matmul(emb, self.bias_pool)                         # N, dim_out

        x_g = torch.einsum("knm,bmc->bknc", supports, z)  # B, cheb_k, N, dim_in
        x_g = x_g.permute(0, 2, 1, 3)                     # B, N, cheb_k, dim_in
        return torch.einsum('bnki,nkio->bno', x_g, weights) + bias  # B, N, dim_out


class VectorField_g(_AGCVectorField):
    """Graph vector field; callable as ``g(z)`` or with the state as second argument."""


class VectorField_only_g(_AGCVectorField):
    """Graph vector field with the same architecture as :class:`VectorField_g`.

    FIXME (carried over from the original source): the output projection is
    ``input_channels`` wide; a ``hidden_channels``-wide alternative was left
    as an open question by the author.
    """


class VectorField_g_prime(_AGCVectorField):
    """Graph vector field whose ``forward`` takes the state ``z`` directly."""

    def forward(self, z):
        return self._apply_field(z)