from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from model.TWDGCN import ConnectionMatrix


class DGCN(nn.Module):
    """Dynamic graph convolution with node-adaptive weights.

    Two supports are combined:

    * an identity support (the input itself), and
    * a dynamic support: a symmetrically normalized adjacency built from
      input-conditioned node embeddings, masked element-wise by ``connMtx``.

    Per-node weights and biases are generated from ``node_embeddings[1]``
    through ``weights_pool`` and ``bias_pool``.

    NOTE: ``forward`` always stacks exactly two supports, so ``cheb_k`` must
    be 2; any other value makes the final einsum fail on the ``k`` axis.
    """

    def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
        """
        Parameters:
        - dim_in: Number of input channels C of ``x``.
        - dim_out: Number of output channels.
        - cheb_k: Number of supports; must be 2 (see class docstring).
        - embed_dim: Node-embedding dimension D.
        """
        super().__init__()
        self.cheb_k = cheb_k
        self.embed_dim = embed_dim

        # Parameter pools. torch.FloatTensor(...) allocates *uninitialized*
        # memory; NOTE(review): presumably the training script initializes
        # all parameters after construction -- confirm.
        self.weights_pool = nn.Parameter(
            torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out)
        )
        # NOTE(review): `weights` and `bias` are not used by forward(); they
        # are kept so parameter registration and state_dict keys are unchanged.
        self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
        self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
        self.bias = nn.Parameter(torch.FloatTensor(dim_out))

        # Hidden sizes of the gate-generating MLP.
        self.hyperGNN_dim = 16
        self.middle_dim = 2

        # MLP mapping input features (dim_in) to an embed_dim-sized gate that
        # modulates node_embeddings[0] in forward().
        self.fc = nn.Sequential(
            OrderedDict(
                [
                    ("fc1", nn.Linear(dim_in, self.hyperGNN_dim)),
                    ("sigmoid1", nn.Sigmoid()),
                    ("fc2", nn.Linear(self.hyperGNN_dim, self.middle_dim)),
                    ("sigmoid2", nn.Sigmoid()),
                    ("fc3", nn.Linear(self.middle_dim, self.embed_dim)),
                ]
            )
        )
        # NOTE(review): not used by forward(); presumably consumed elsewhere
        # (or kept for checkpoint compatibility) -- confirm before removing.
        self.conn_matrix = ConnectionMatrix.ConnectionMatrix()

    def forward(self, x, node_embeddings, connMtx):
        """
        Forward pass for the DGCN model.

        Parameters:
        - x: Input tensor of shape [B, N, C] with C == dim_in.
        - node_embeddings: Indexable pair of embeddings.
            node_embeddings[0]: 3-D tensor [B, N, D]; gated by the input to
                build the dynamic adjacency.
            node_embeddings[1]: [N, D]; generates per-node weights and biases.
        - connMtx: Mask multiplied element-wise with the [B, N, N]
            normalized adjacency; the product must be [B, N, N].

        Returns:
        - x_gconv: Output tensor of shape [B, N, dim_out]
        """
        dynamic_emb = node_embeddings[0]
        static_emb = node_embeddings[1]

        node_num = dynamic_emb.shape[1]
        identity = torch.eye(
            node_num, device=dynamic_emb.device, dtype=dynamic_emb.dtype
        )

        # Input-conditioned gate (named `gate` rather than `filter`, which
        # would shadow the builtin).
        gate = self.fc(x)
        nodevec = torch.tanh(torch.mul(dynamic_emb, gate))

        # Dynamic adjacency: non-negative similarity of the gated embeddings,
        # symmetrically normalized, then masked by connMtx.
        adj = F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1)))
        supports2 = self.get_laplacian(adj, identity)
        supports3 = connMtx * supports2

        # The first support is the identity, so I @ x is simply x; skip the
        # O(B * N^2 * C) einsum that the identity product would cost.
        x_g1 = x
        x_g2 = torch.einsum("bnm,bmc->bnc", supports3, x)
        # Stack on the support axis directly as [B, N, K=2, C] (same tensor as
        # stacking on dim 1 and permuting (0, 2, 1, 3)).
        x_g = torch.stack([x_g1, x_g2], dim=2)

        # Node-specific parameters generated from the static embeddings.
        weights = torch.einsum("nd,dkio->nkio", static_emb, self.weights_pool)
        bias = torch.matmul(static_emb, self.bias_pool)

        x_gconv = torch.einsum("bnki,nkio->bno", x_g, weights) + bias
        return x_gconv

    @staticmethod
    def get_laplacian(graph, I, normalize=True):
        """
        Symmetrically normalize a (batched) adjacency: D^-1/2 A D^-1/2.

        D is the diagonal of row sums of A. Despite the flag name, both modes
        normalize: with ``normalize=False`` the identity ``I`` is first added
        to ``graph`` (self-loops) and the result is normalized the same way.

        Nodes whose degree is exactly zero get a scale factor of 0 instead of
        inf, so they produce zero rows/columns rather than NaN (in both the
        forward and the backward pass).

        Parameters:
        - graph: Adjacency tensor [..., N, N] ([B, N, N] when called from
          forward).
        - I: Identity matrix [N, N]; only used when ``normalize`` is False.
        - normalize: If False, add self-loops before normalizing.

        Returns:
        - L: Normalized adjacency with the same shape as ``graph``.
        """
        if not normalize:
            graph = graph + I

        degree = torch.sum(graph, dim=-1)
        isolated = degree == 0
        # Raise a "safe" degree to -1/2 so neither the value nor its gradient
        # becomes inf/NaN for isolated nodes, then zero those entries.
        d_inv_sqrt = degree.masked_fill(isolated, 1.0) ** (-1 / 2)
        d_inv_sqrt = d_inv_sqrt.masked_fill(isolated, 0.0)

        D_inv_sqrt = torch.diag_embed(d_inv_sqrt)
        L = torch.matmul(torch.matmul(D_inv_sqrt, graph), D_inv_sqrt)
        return L