import torch
import torch.nn as nn
import torch.nn.functional as F


class AVWGCN(nn.Module):
    """Adaptive graph convolution: learns the adjacency matrix from node
    embeddings and generates node-specific weights from shared parameter pools."""
    def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
        super(AVWGCN, self).__init__()
        self.cheb_k = cheb_k
        # Shared pools from which per-node parameters are generated in forward().
        self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
        self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
        # torch.FloatTensor allocates uninitialized memory, so initialize explicitly.
        nn.init.xavier_uniform_(self.weights_pool)
        nn.init.xavier_uniform_(self.bias_pool)
    def forward(self, x, node_embeddings):
        # x: [B, N, C], node_embeddings: [N, D] -> supports: [N, N]
        # output: [B, N, dim_out]
        node_num = node_embeddings.shape[0]
        # Adaptive adjacency: row-wise softmax over the ReLU'd embedding similarity.
        supports = F.softmax(F.relu(torch.mm(node_embeddings, node_embeddings.transpose(0, 1))), dim=1)
        # Chebyshev basis: T_0 = I, T_1 = A, T_k = 2*A*T_{k-1} - T_{k-2}; default cheb_k = 3
        support_set = [torch.eye(node_num, device=supports.device), supports]
        for k in range(2, self.cheb_k):
            support_set.append(torch.matmul(2 * supports, support_set[-1]) - support_set[-2])
        supports = torch.stack(support_set, dim=0)  # [cheb_k, N, N]
        # Node-specific parameters generated from the shared pools via the embeddings.
        weights = torch.einsum('nd,dkio->nkio', node_embeddings, self.weights_pool)  # [N, cheb_k, dim_in, dim_out]
        bias = torch.matmul(node_embeddings, self.bias_pool)  # [N, dim_out]
        x_g = torch.einsum('knm,bmc->bknc', supports, x)  # [B, cheb_k, N, dim_in]
        x_g = x_g.permute(0, 2, 1, 3)  # [B, N, cheb_k, dim_in]
        x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias  # [B, N, dim_out]
        return x_gconv
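

# A minimal smoke test (an illustrative sketch, not part of the original module):
# the batch size, node count, and dimensions below are arbitrary assumptions.
if __name__ == "__main__":
    B, N, C, D = 8, 20, 2, 10                # batch, nodes, input dim, embed dim
    layer = AVWGCN(dim_in=C, dim_out=64, cheb_k=3, embed_dim=D)
    x = torch.randn(B, N, C)                 # random node signals
    node_embeddings = torch.randn(N, D)      # learned in the full model; random here
    out = layer(x, node_embeddings)
    print(out.shape)                         # torch.Size([8, 20, 64])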