TrafficWheel/model/TWDGCN/DGCN.py

107 lines
3.5 KiB
Python
Executable File

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict
from model.TWDGCN import ConnectionMatrix
class DGCN(nn.Module):
    """Dynamic graph convolution with node-specific weights.

    For every forward pass an input-dependent adjacency is built from the
    dynamic node embeddings (gated by an MLP over the input features),
    degree-normalized, masked by ``connMtx`` and used to propagate ``x``.
    The propagated terms are mixed with per-node weights and biases that are
    generated from the static node embeddings.
    """

    def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
        super(DGCN, self).__init__()
        # Number of propagation terms mixed by ``weights_pool`` (T_0 x ... T_{K-1} x).
        self.cheb_k = cheb_k
        self.embed_dim = embed_dim
        # torch.FloatTensor(shape) allocates UNINITIALIZED memory; these
        # parameters must be initialized elsewhere before use (TODO confirm where).
        # Pool from which per-node weights are generated: [embed_dim, K, dim_in, dim_out].
        self.weights_pool = nn.Parameter(
            torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out)
        )
        # NOTE(review): ``weights`` and ``bias`` are not read by ``forward``;
        # left in place so existing checkpoints / callers keep matching.
        self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
        # Pool from which per-node biases are generated: [embed_dim, dim_out].
        self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
        self.bias = nn.Parameter(torch.FloatTensor(dim_out))
        # Hidden sizes of the feature -> filter MLP below.
        self.hyperGNN_dim = 16
        self.middle_dim = 2
        # MLP mapping input features [..., dim_in] to a filter [..., embed_dim]
        # that gates the dynamic node embeddings in ``forward``.
        self.fc = nn.Sequential(
            OrderedDict(
                [
                    ("fc1", nn.Linear(dim_in, self.hyperGNN_dim)),
                    ("sigmoid1", nn.Sigmoid()),
                    ("fc2", nn.Linear(self.hyperGNN_dim, self.middle_dim)),
                    ("sigmoid2", nn.Sigmoid()),
                    ("fc3", nn.Linear(self.middle_dim, self.embed_dim)),
                ]
            )
        )
        # NOTE(review): not used by ``forward`` (the mask arrives as the
        # ``connMtx`` argument); kept since callers or checkpoints may rely on it.
        self.conn_matrix = ConnectionMatrix.ConnectionMatrix()

    def forward(self, x, node_embeddings, connMtx):
        """
        Forward pass for the DGCN model.
        Parameters:
        - x: Input tensor of shape [B, N, C] with C == dim_in
        - node_embeddings: indexable pair of embeddings.
          ``node_embeddings[0]`` (dynamic) must broadcast against
          [B, N, embed_dim] and have N at dim 1;
          ``node_embeddings[1]`` (static) has shape [N, embed_dim].
        - connMtx: mask multiplied element-wise into the learned adjacency;
          must broadcast against [B, N, N]
        Returns:
        - x_gconv: Output tensor of shape [B, N, dim_out]
        """
        dyn_emb = node_embeddings[0]
        static_emb = node_embeddings[1]
        node_num = dyn_emb.shape[1]
        # Identity matrix; only needed by get_laplacian's self-loop branch.
        supports1 = torch.eye(node_num, device=dyn_emb.device)
        # Gate the dynamic embeddings with a filter computed from the input.
        gate = self.fc(x)  # [B, N, embed_dim]
        nodevec = torch.tanh(torch.mul(dyn_emb, gate))
        # Non-negative similarity graph, symmetrically degree-normalized.
        supports2 = self.get_laplacian(
            F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1))), supports1
        )
        supports3 = connMtx * supports2  # Apply dynamic heterogeneity matrix mask
        # Propagation terms: T_0 x = x (identity support), T_1 x = A x and,
        # for cheb_k > 2, the Chebyshev recurrence T_k x = 2 A T_{k-1} x - T_{k-2} x.
        # For cheb_k == 2 this is exactly [x, A x].
        terms = [x, torch.einsum("bnm,bmc->bnc", supports3, x)]
        for _ in range(2, self.cheb_k):
            terms.append(
                2 * torch.einsum("bnm,bmc->bnc", supports3, terms[-1]) - terms[-2]
            )
        x_g = torch.stack(terms[: self.cheb_k], dim=2)  # [B, N, K, C]
        # Node-specific weights and biases generated from the static embeddings.
        weights = torch.einsum("nd,dkio->nkio", static_emb, self.weights_pool)
        bias = torch.matmul(static_emb, self.bias_pool)  # [N, dim_out]
        x_gconv = torch.einsum("bnki,nkio->bno", x_g, weights) + bias
        return x_gconv

    @staticmethod
    def get_laplacian(graph, I, normalize=True):
        """
        Symmetrically degree-normalize a (possibly batched) adjacency matrix.

        Despite the name, this returns the normalized adjacency
        ``D^{-1/2} A D^{-1/2}`` (not ``I - D^{-1/2} A D^{-1/2}``), where ``D``
        holds the row sums of ``A``.
        Parameters:
        - graph: adjacency matrix, [..., N, N]
        - I: identity matrix; added as self-loops only when ``normalize`` is False
        - normalize: if True, normalize ``graph`` as given; if False, normalize
          ``graph + I`` instead (both branches are degree-normalized)
        Returns:
        - L: normalized adjacency, same shape as ``graph``. Rows/columns of
          nodes with zero degree are 0 instead of NaN.
        """
        if not normalize:
            graph = graph + I
        degree = torch.sum(graph, dim=-1)
        nonzero = degree != 0
        # Double-where: feed a safe value into the power for zero-degree nodes so
        # neither the forward value nor its gradient becomes inf/NaN.
        safe_degree = torch.where(nonzero, degree, torch.ones_like(degree))
        d_inv_sqrt = torch.where(
            nonzero, safe_degree ** (-1 / 2), torch.zeros_like(degree)
        )
        # Same as diag(d) @ graph @ diag(d), without the O(N^3) matmuls.
        L = d_inv_sqrt.unsqueeze(-1) * graph * d_inv_sqrt.unsqueeze(-2)
        return L