# Source listing metadata: 126 lines, 4.5 KiB, Python.
import torch
|
|
import torch.nn.functional as F
|
|
from torch.nn import Linear, Sequential
|
|
from torch_geometric.data import Data
|
|
from torch_geometric.data.batch import Batch
|
|
from torch_geometric.nn.glob import global_add_pool, global_mean_pool, \
|
|
global_max_pool
|
|
|
|
from federatedscope.gfl.model.gcn import GCN_Net
|
|
from federatedscope.gfl.model.sage import SAGE_Net
|
|
from federatedscope.gfl.model.gat import GAT_Net
|
|
from federatedscope.gfl.model.gin import GIN_Net
|
|
from federatedscope.gfl.model.gpr import GPR_Net
|
|
|
|
# Number of rows in each AtomEncoder embedding table (torch.nn.Embedding(EMD_DIM, hidden)):
# every integer feature value looked up by AtomEncoder must lie in [0, EMD_DIM).
EMD_DIM = 200
|
|
|
|
|
|
class AtomEncoder(torch.nn.Module):
    r"""Embed integer node features column by column and sum the embeddings.

    One ``torch.nn.Embedding(EMD_DIM, hidden)`` table is created per input
    column, and each table's weights are re-initialised with Xavier-uniform.

    Arguments:
        in_channels (int): number of embedding tables to create (one per
            feature column).
        hidden (int): size of every embedding vector.
    """
    def __init__(self, in_channels, hidden):
        super(AtomEncoder, self).__init__()
        tables = []
        for _ in range(in_channels):
            table = torch.nn.Embedding(EMD_DIM, hidden)
            torch.nn.init.xavier_uniform_(table.weight.data)
            tables.append(table)
        self.atom_embedding_list = torch.nn.ModuleList(tables)

    def forward(self, x):
        """Return ``sum_c self.atom_embedding_list[c](x[:, c])``.

        ``x`` is indexed as ``x[:, c]`` for every ``c < x.shape[1]``, so it
        must be at least 2-D and must not have more columns than there are
        embedding tables; otherwise the ModuleList lookup raises
        ``IndexError``. When ``x`` has zero columns the integer ``0`` is
        returned.
        """
        n_cols = x.shape[1]
        return sum(self.atom_embedding_list[c](x[:, c])
                   for c in range(n_cols))
|
|
|
|
|
|
class GNN_Net_Graph(torch.nn.Module):
    r"""GNN model with pre-linear layer, pooling layer
    and output layer for graph classification tasks.

    Pipeline: node encoder -> GNN backbone -> graph-level pooling ->
    ``Linear + ReLU`` -> dropout -> ``Linear`` classifier.

    Arguments:
        in_channels (int): input channels.
        out_channels (int): output channels.
        hidden (int): hidden dim for all modules.
        max_depth (int): number of layers for gnn (passed as ``K`` when
            ``gnn == 'gpr'``).
        dropout (float): dropout probability.
        gnn (str): name of gnn type, one of ("gcn", "sage", "gat", "gin"
            or "gpr").
        pooling (str): pooling method, use ("add", "mean" or "max").

    Raises:
        ValueError: if ``gnn`` or ``pooling`` is not a supported name.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden=64,
                 max_depth=2,
                 dropout=.0,
                 gnn='gcn',
                 pooling='add'):
        super(GNN_Net_Graph, self).__init__()
        self.dropout = dropout

        # Embedding (pre) layer. Both encoders are always built; forward()
        # picks one per call from the dtype of the node features
        # (int64 -> per-column embedding lookup, anything else -> Linear).
        self.encoder_atom = AtomEncoder(in_channels, hidden)
        self.encoder = Linear(in_channels, hidden)

        # GNN layer: every backbone maps hidden -> hidden.
        backbone_kwargs = dict(in_channels=hidden,
                               out_channels=hidden,
                               hidden=hidden,
                               dropout=dropout)
        if gnn == 'gcn':
            self.gnn = GCN_Net(max_depth=max_depth, **backbone_kwargs)
        elif gnn == 'sage':
            self.gnn = SAGE_Net(max_depth=max_depth, **backbone_kwargs)
        elif gnn == 'gat':
            self.gnn = GAT_Net(max_depth=max_depth, **backbone_kwargs)
        elif gnn == 'gin':
            self.gnn = GIN_Net(max_depth=max_depth, **backbone_kwargs)
        elif gnn == 'gpr':
            # GPR_Net takes its depth under the name ``K``.
            self.gnn = GPR_Net(K=max_depth, **backbone_kwargs)
        else:
            raise ValueError(f'Unsupported gnn type: {gnn}.')

        # Pooling layer
        if pooling == 'add':
            self.pooling = global_add_pool
        elif pooling == 'mean':
            self.pooling = global_mean_pool
        elif pooling == 'max':
            self.pooling = global_max_pool
        else:
            raise ValueError(f'Unsupported pooling type: {pooling}.')

        # Output layer
        self.linear = Sequential(Linear(hidden, hidden), torch.nn.ReLU())
        self.clf = Linear(hidden, out_channels)

    def forward(self, data):
        """Compute graph-level outputs.

        Args:
            data: a PyG ``Batch``, a single un-batched PyG ``Data`` graph, or
                a tuple ``(x, edge_index, batch)``. ``batch`` may be missing
                or ``None``; in that case all nodes are treated as one graph.

        Returns:
            The output of ``self.clf`` applied to node embeddings pooled by the
            graph ids in ``batch``.

        Raises:
            TypeError: if ``data`` is none of the supported input types.
        """
        if isinstance(data, (Batch, Data)):
            x, edge_index = data.x, data.edge_index
            # getattr: a plain Data graph may carry no ``batch`` attribute.
            batch = getattr(data, 'batch', None)
        elif isinstance(data, tuple):
            x, edge_index, batch = data
        else:
            raise TypeError('Unsupported data type!')

        if batch is None:
            # No graph-assignment vector: every node belongs to graph 0, so
            # pooling produces a single row for the whole input.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        if x.dtype == torch.int64:
            x = self.encoder_atom(x)
        else:
            x = self.encoder(x)

        x = self.gnn((x, edge_index))
        x = self.pooling(x, batch)
        x = self.linear(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.clf(x)
        return x
|