# File: TrafficWheel/model/AGCRN/AGCN.py
# (page metadata: 26 lines, 1.4 KiB, Python)

import torch
import torch.nn.functional as F
import torch.nn as nn
class AVWGCN(nn.Module):
    """Adaptive graph convolution with node-adaptive weights.

    A dense adjacency matrix is learned from node embeddings
    (``softmax(relu(E @ E^T))``, normalised over dim 1) and expanded into
    ``cheb_k`` Chebyshev supports. The convolution weights and bias for each
    node are generated from the same embeddings via ``weights_pool`` and
    ``bias_pool``.

    Args:
        dim_in: number of input channels of ``x``.
        dim_out: number of output channels.
        cheb_k: number of Chebyshev supports (T_0 .. T_{cheb_k-1}); must be >= 1.
        embed_dim: width D of the node embeddings.

    Raises:
        ValueError: if ``cheb_k`` < 1.
    """

    def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
        super().__init__()
        if cheb_k < 1:
            raise ValueError(f"cheb_k must be >= 1, got {cheb_k}")
        self.cheb_k = cheb_k
        # The legacy torch.FloatTensor(...) constructor left these parameters
        # holding undefined memory. Zero-filling keeps them deterministic
        # without consuming the RNG, so any external re-initialisation yields
        # exactly the same result as before.
        # NOTE(review): zeros is not a training-quality init; confirm the
        # caller applies e.g. nn.init.xavier_uniform_ before training.
        self.weights_pool = nn.Parameter(torch.zeros(embed_dim, cheb_k, dim_in, dim_out))
        self.bias_pool = nn.Parameter(torch.zeros(embed_dim, dim_out))

    def forward(self, x, node_embeddings):
        """Apply the adaptive graph convolution.

        Args:
            x: input features, shape [B, N, dim_in].
            node_embeddings: node embeddings, shape [N, embed_dim].

        Returns:
            Tensor of shape [B, N, dim_out].
        """
        node_num = node_embeddings.shape[0]
        # Learned adjacency: ReLU'd embedding similarity, softmax over dim 1.
        adj = F.softmax(F.relu(torch.mm(node_embeddings, node_embeddings.transpose(0, 1))), dim=1)

        # Chebyshev recurrence: T_0 = I, T_1 = A, T_k = 2 A T_{k-1} - T_{k-2}.
        # Truncate to exactly cheb_k terms so the k axis matches weights_pool
        # (cheb_k == 1 previously still produced two supports).
        # The identity is created with adj's dtype/device to avoid a per-call
        # host allocation + copy and any dtype mismatch in torch.stack.
        identity = torch.eye(node_num, dtype=adj.dtype, device=adj.device)
        support_set = [identity, adj][: self.cheb_k]
        two_adj = 2 * adj
        for _ in range(2, self.cheb_k):
            support_set.append(torch.matmul(two_adj, support_set[-1]) - support_set[-2])
        supports = torch.stack(support_set, dim=0)  # cheb_k, N, N

        weights = torch.einsum("nd,dkio->nkio", node_embeddings, self.weights_pool)  # N, cheb_k, dim_in, dim_out
        bias = torch.matmul(node_embeddings, self.bias_pool)  # N, dim_out
        x_g = torch.einsum("knm,bmc->bnkc", supports, x)  # B, N, cheb_k, dim_in
        return torch.einsum("bnki,nkio->bno", x_g, weights) + bias  # B, N, dim_out