# Source: FS-TFP/federatedscope/trafficflow/model/DGCN.py (70 lines, 2.7 KiB, Python)

import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
from collections import OrderedDict
class DGCN(nn.Module):
    """Graph convolution over an input-dependent graph with per-node weights.

    For every sample, ``forward`` builds a graph from the dynamic node
    embeddings gated by an MLP of the input, normalizes it, and mixes the
    identity-propagated and graph-propagated signals with weights generated
    from the static node embeddings.

    Args:
        dim_in: size of the last dimension of the input ``x``.
        dim_out: size of the last dimension of the output.
        cheb_k: number of support matrices the weight pool is built for.
            ``forward`` always stacks exactly two supports (identity and the
            normalized dynamic graph), so this has to be 2 for the final
            einsum to line up.
        embed_dim: size of the node embedding vectors.
    """

    def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
        super().__init__()
        self.cheb_k = cheb_k

        # Uninitialized storage; every tensor is filled by the initializers
        # right below. Registration order is kept stable so state_dict keys
        # and parameter ordering do not change.
        self.weights_pool = nn.Parameter(
            torch.empty(embed_dim, cheb_k, dim_in, dim_out))
        # NOTE: `weights` and `bias` are not read by forward(); they remain
        # registered so that existing state_dicts still load.
        self.weights = nn.Parameter(torch.empty(cheb_k, dim_in, dim_out))
        self.bias_pool = nn.Parameter(torch.empty(embed_dim, dim_out))
        self.bias = nn.Parameter(torch.empty(dim_out))

        # Initialize parameters
        nn.init.xavier_uniform_(self.weights_pool)
        nn.init.xavier_uniform_(self.weights)
        nn.init.zeros_(self.bias_pool)
        nn.init.zeros_(self.bias)

        self.hyperGNN_dim = 16
        self.middle_dim = 2
        self.embed_dim = embed_dim

        # MLP mapping each node's input features to an embed_dim-sized gate
        # that is multiplied element-wise with the dynamic node embeddings.
        self.fc = nn.Sequential(
            OrderedDict([
                ('fc1', nn.Linear(dim_in, self.hyperGNN_dim)),
                ('sigmoid1', nn.Sigmoid()),
                ('fc2', nn.Linear(self.hyperGNN_dim, self.middle_dim)),
                ('sigmoid2', nn.Sigmoid()),
                ('fc3', nn.Linear(self.middle_dim, self.embed_dim)),
            ]))

    def forward(self, x, node_embeddings):
        """Apply the dynamic graph convolution.

        Args:
            x: input signal, ``[B, N, dim_in]``.
            node_embeddings: indexable pair; item 0 is the dynamic embedding
                (expected ``[B, N, embed_dim]``), item 1 is the static
                embedding ``[N, embed_dim]``.

        Returns:
            Tensor of shape ``[B, N, dim_out]``.
        """
        dynamic_emb, static_emb = node_embeddings[0], node_embeddings[1]
        node_num = dynamic_emb.shape[1]

        # Build the identity directly with x's device and dtype: the default
        # float32 CPU identity does not match x once the module has been
        # cast with .double()/.half().
        supports1 = torch.eye(node_num, device=x.device, dtype=x.dtype)

        # Gate the dynamic embeddings with an input-dependent filter.
        node_filter = self.fc(x)  # [B, N, embed_dim]
        nodevec = torch.tanh(torch.mul(dynamic_emb, node_filter))  # [B, N, embed_dim]

        # Non-negative pairwise similarity -> per-sample adjacency [B, N, N].
        graph = F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1)))
        supports2 = DGCN.get_laplacian(graph, supports1)

        # Propagate x through both supports and stack them on a "k" axis.
        x_g1 = torch.einsum("nm,bmc->bnc", supports1, x)
        x_g2 = torch.einsum("bnm,bmc->bnc", supports2, x)
        x_g = torch.stack([x_g1, x_g2], dim=1)  # [B, 2, N, dim_in]
        x_g = x_g.permute(0, 2, 1, 3)  # [B, N, 2, dim_in]

        # Node-specific weights and bias generated from the static embedding.
        weights = torch.einsum('nd,dkio->nkio', static_emb, self.weights_pool)
        bias = torch.matmul(static_emb, self.bias_pool)  # [N, dim_out]

        return torch.einsum('bnki,nkio->bno', x_g, weights) + bias

    @staticmethod
    def get_laplacian(graph, I, normalize=True):
        """Return the symmetrically scaled graph ``D^{-1/2} A D^{-1/2}``.

        :param graph: adjacency without self loops, ``[N, N]`` or batched
            ``[..., N, N]``; row sums over the last axis are the degrees.
        :param I: identity matrix ``[N, N]``; only used when ``normalize`` is
            False, where it is added to ``graph`` as self loops.
        :param normalize: if True, scale ``graph`` as-is and add a small
            epsilon to the degrees so zero-degree rows do not divide by zero;
            if False, add self loops first and use the degrees unmodified.
        :return: tensor with the same shape as ``graph``.
        """
        if normalize:
            epsilon = 1e-6
            degree = torch.sum(graph, dim=-1) + epsilon
        else:
            graph = graph + I
            degree = torch.sum(graph, dim=-1)

        d_inv_sqrt = degree ** (-1 / 2)
        # Scaling rows and columns by broadcasting equals D @ graph @ D with
        # D = diag(d_inv_sqrt), without building a dense diagonal matrix or
        # doing two O(N^3) matmuls.
        return graph * d_inv_sqrt.unsqueeze(-1) * d_inv_sqrt.unsqueeze(-2)