import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict

from model.TWDGCN import ConnectionMatrix


class DGCN(nn.Module):
    """Dynamic graph convolution with node-adaptive weights generated from node embeddings."""

    def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
        super(DGCN, self).__init__()
        self.cheb_k = cheb_k
        self.embed_dim = embed_dim

        # Initialize parameters. Node-specific weights and biases are generated
        # from these pools in forward(). Note that torch.FloatTensor allocates
        # uninitialized storage, so these should be initialized before training.
        # (self.weights and self.bias are not used in forward().)
        self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
        self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
        self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
        self.bias = nn.Parameter(torch.FloatTensor(dim_out))

        # Hyperparameters of the fully connected stack below
        self.hyperGNN_dim = 16
        self.middle_dim = 2

        # Fully connected layers: dim_in -> hyperGNN_dim -> middle_dim -> embed_dim
        self.fc = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(dim_in, self.hyperGNN_dim)),
            ('sigmoid1', nn.Sigmoid()),
            ('fc2', nn.Linear(self.hyperGNN_dim, self.middle_dim)),
            ('sigmoid2', nn.Sigmoid()),
            ('fc3', nn.Linear(self.middle_dim, self.embed_dim))
        ]))

        self.conn_matrix = ConnectionMatrix.ConnectionMatrix()

    def forward(self, x, node_embeddings, connMtx):
        """
        Forward pass for the DGCN layer.

        Parameters:
        - x: Input tensor of shape [B, N, C], with C == dim_in
        - node_embeddings: Pair of embeddings: node_embeddings[0] is a dynamic
          embedding of shape [B, N, D] and node_embeddings[1] is a static
          embedding of shape [N, D], with D == embed_dim
        - connMtx: Connectivity matrix, broadcastable to [B, N, N]

        Returns:
        - x_gconv: Output tensor of shape [B, N, dim_out]
        """
        node_num = node_embeddings[0].shape[1]
        supports1 = torch.eye(node_num).to(node_embeddings[0].device)  # Identity support

        # Map the input into the embedding space and use it to modulate the
        # dynamic node embeddings
        filter = self.fc(x)  # [B, N, embed_dim]
        nodevec = torch.tanh(torch.mul(node_embeddings[0], filter))  # Element-wise multiplication

        # Build a normalized similarity graph, then mask it with the dynamic
        # heterogeneity (connectivity) matrix
        supports2 = self.get_laplacian(F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1))), supports1)
        supports3 = connMtx * supports2

        # Graph convolution over the identity support and the masked dynamic support
        x_g1 = torch.einsum("nm,bmc->bnc", supports1, x)
        x_g2 = torch.einsum("bnm,bmc->bnc", supports3, x)
        x_g = torch.stack([x_g1, x_g2], dim=1)  # [B, 2, N, dim_in]

        # Generate node-specific weights and biases from the static embeddings
        weights = torch.einsum('nd,dkio->nkio', node_embeddings[1], self.weights_pool)  # [N, cheb_k, dim_in, dim_out]
        bias = torch.matmul(node_embeddings[1], self.bias_pool)  # [N, dim_out]

        x_g = x_g.permute(0, 2, 1, 3)  # [B, N, 2, dim_in]
        x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias  # Graph convolution operation

        return x_gconv

    @staticmethod
    def get_laplacian(graph, I, normalize=True):
        """
        Compute the symmetrically normalized graph matrix D^(-1/2) A D^(-1/2).

        Parameters:
        - graph: Adjacency matrix of the graph, [N, N] or [B, N, N]
        - I: Identity matrix, added as self-loops only when normalize is False
        - normalize: If True, normalize the graph as given; if False, add
          self-loops before normalizing

        Returns:
        - L: Normalized graph matrix
        """
        if normalize:
            D_inv_sqrt = torch.diag_embed(torch.sum(graph, dim=-1) ** (-1 / 2))
            L = torch.matmul(torch.matmul(D_inv_sqrt, graph), D_inv_sqrt)
        else:
            graph = graph + I
            D_inv_sqrt = torch.diag_embed(torch.sum(graph, dim=-1) ** (-1 / 2))
            L = torch.matmul(torch.matmul(D_inv_sqrt, graph), D_inv_sqrt)
        return L
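

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; not part of the original module). The
# tensor shapes below are assumptions inferred from the einsum patterns in
# forward(): x is [B, N, dim_in], node_embeddings is a pair of a dynamic
# [B, N, embed_dim] embedding and a static [N, embed_dim] embedding, and
# connMtx broadcasts to [B, N, N]. cheb_k is set to 2 because forward()
# stacks exactly two supports.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    B, N, dim_in, dim_out, embed_dim = 4, 10, 3, 8, 6
    layer = DGCN(dim_in, dim_out, cheb_k=2, embed_dim=embed_dim)

    # The parameter pools are allocated uninitialized, so give them values
    # before running a forward pass.
    for p in layer.parameters():
        nn.init.uniform_(p, -0.1, 0.1)

    x = torch.randn(B, N, dim_in)
    node_embeddings = [torch.randn(B, N, embed_dim), torch.randn(N, embed_dim)]
    connMtx = torch.ones(B, N, N)  # placeholder connectivity mask

    out = layer(x, node_embeddings, connMtx)
    print(out.shape)  # expected: torch.Size([4, 10, 8])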