import torch
import torch.nn.functional as F
import torch.nn as nn


class AVWGCN(nn.Module):
    """Graph convolution with node-specific weights and a learned adjacency.

    Per-node weights and biases are generated by multiplying the node
    embeddings with shared parameter pools. The adjacency matrix is derived
    from the node embeddings in ``forward`` and expanded into ``cheb_k``
    Chebyshev-style supports.
    """

    def __init__(self, dim_in: int, dim_out: int, cheb_k: int, embed_dim: int) -> None:
        """Create the parameter pools.

        Args:
            dim_in: Number of input features per node.
            dim_out: Number of output features per node.
            cheb_k: Number of supports (polynomial terms) to use; must be >= 1.
            embed_dim: Dimensionality of the node embeddings.

        Raises:
            ValueError: If ``cheb_k`` is smaller than 1.
        """
        super().__init__()
        if cheb_k < 1:
            raise ValueError(f"cheb_k must be >= 1, got {cheb_k}")
        self.cheb_k = cheb_k
        # torch.empty only allocates memory; values are set in reset_parameters().
        self.weights_pool = nn.Parameter(torch.empty(embed_dim, cheb_k, dim_in, dim_out))
        self.bias_pool = nn.Parameter(torch.empty(embed_dim, dim_out))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise the pools so the layer never starts from raw memory."""
        nn.init.xavier_uniform_(self.weights_pool)
        nn.init.zeros_(self.bias_pool)

    def forward(self, x: torch.Tensor, node_embeddings: torch.Tensor) -> torch.Tensor:
        """Apply the adaptive graph convolution.

        Args:
            x: Input of shape ``[B, N, dim_in]``.
            node_embeddings: Node embeddings of shape ``[N, embed_dim]``.

        Returns:
            Tensor of shape ``[B, N, dim_out]``.
        """
        node_num = node_embeddings.shape[0]
        # Learned adjacency: row-wise softmax over ReLU(E @ E^T), shape [N, N].
        supports = F.softmax(
            F.relu(torch.mm(node_embeddings, node_embeddings.transpose(0, 1))), dim=1
        )
        # Build the identity directly on the target device/dtype instead of
        # creating it on CPU in the default dtype and copying it over.
        identity = torch.eye(node_num, device=supports.device, dtype=supports.dtype)
        support_set = [identity, supports]
        # Recurrence T_k = 2 * A @ T_{k-1} - T_{k-2}.
        for _ in range(2, self.cheb_k):
            support_set.append(torch.matmul(2 * supports, support_set[-1]) - support_set[-2])
        # Truncate so that cheb_k == 1 uses only the identity term and the
        # stack always matches the cheb_k axis of the weight pool.
        supports = torch.stack(support_set[:self.cheb_k], dim=0)  # cheb_k, N, N
        # Node-specific parameters generated from the shared pools.
        weights = torch.einsum('nd,dkio->nkio', node_embeddings, self.weights_pool)  # N, cheb_k, dim_in, dim_out
        bias = torch.matmul(node_embeddings, self.bias_pool)  # N, dim_out
        x_g = torch.einsum("knm,bmc->bknc", supports, x)  # B, cheb_k, N, dim_in
        x_g = x_g.permute(0, 2, 1, 3)  # B, N, cheb_k, dim_in
        x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias  # B, N, dim_out
        return x_gconv