# 89 lines · 3.2 KiB · Python
import torch
|
|
from torch_geometric.data import Data
|
|
|
|
from federatedscope.core.mlp import MLP
|
|
from federatedscope.gfl.model.gcn import GCN_Net
|
|
from federatedscope.gfl.model.sage import SAGE_Net
|
|
from federatedscope.gfl.model.gat import GAT_Net
|
|
from federatedscope.gfl.model.gin import GIN_Net
|
|
from federatedscope.gfl.model.gpr import GPR_Net
|
|
|
|
|
|
class GNN_Net_Link(torch.nn.Module):
    """GNN node encoder paired with an MLP link predictor."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden=64,
                 max_depth=2,
                 dropout=.0,
                 gnn='gcn',
                 layers=2):
        r"""GNN model with LinkPredictor for link prediction tasks.

        ``forward`` encodes nodes with the selected GNN backbone into
        ``hidden``-dimensional embeddings. ``link_predictor`` multiplies the
        embeddings of each edge's two endpoints element-wise and feeds the
        result to an MLP whose last channel is ``out_channels``.

        Arguments:
            in_channels (int): input node feature dimension.
            out_channels (int): output dimension of the LinkPredictor.
            hidden (int): hidden dim for all modules; also the width of the
                node embeddings produced by the gnn backbone.
            max_depth (int): number of layers for the gnn (forwarded as
                ``K`` when ``gnn == 'gpr'``).
            dropout (float): dropout probability, forwarded to the gnn
                backbone.
            gnn (str): name of gnn type, one of ``"gcn"``, ``"sage"``,
                ``"gat"``, ``"gin"`` or ``"gpr"``.
            layers (int): number of extra ``hidden`` entries in the
                LinkPredictor MLP's channel list, which is
                ``[hidden] + [hidden] * layers + [out_channels]``.

        Raises:
            ValueError: if ``gnn`` is not one of the supported types.
        """
        super().__init__()
        # Not read by this class's own methods; kept as a public attribute.
        self.dropout = dropout

        # GNN layer. Every backbone emits ``hidden``-wide node embeddings so
        # they can be fed straight into the link predictor below.
        backbone_kwargs = dict(in_channels=in_channels,
                               out_channels=hidden,
                               hidden=hidden,
                               dropout=dropout)
        if gnn == 'gpr':
            # GPR_Net names its propagation depth ``K`` instead of
            # ``max_depth``.
            self.gnn = GPR_Net(K=max_depth, **backbone_kwargs)
        else:
            backbones = {
                'gcn': GCN_Net,
                'sage': SAGE_Net,
                'gat': GAT_Net,
                'gin': GIN_Net,
            }
            if gnn not in backbones:
                raise ValueError(f'Unsupported gnn type: {gnn}.')
            self.gnn = backbones[gnn](max_depth=max_depth, **backbone_kwargs)

        # LinkPredictor: maps a combined pair of node embeddings to
        # ``out_channels`` outputs.
        self.output = MLP([hidden] + [hidden] * layers + [out_channels],
                          batch_norm=True)

    def forward(self, data):
        """Encode the nodes of a graph with the gnn backbone.

        Args:
            data: either a ``torch_geometric.data.Data`` object (its ``x``
                and ``edge_index`` attributes are used) or an
                ``(x, edge_index)`` tuple.

        Returns:
            Whatever the gnn backbone returns for ``(x, edge_index)``;
            presumably node embeddings of width ``hidden`` (TODO confirm
            against the backbone implementations).

        Raises:
            TypeError: if ``data`` is neither a ``Data`` nor a tuple.
        """
        if isinstance(data, Data):
            x, edge_index = data.x, data.edge_index
        elif isinstance(data, tuple):
            x, edge_index = data
        else:
            raise TypeError('Unsupported data type!')
        return self.gnn((x, edge_index))

    def link_predictor(self, x, edge_index):
        """Score candidate edges from node embeddings.

        Each edge is represented by the element-wise product of the
        embeddings of its two endpoints, ``x[edge_index[0]]`` and
        ``x[edge_index[1]]``. That representation is then passed through
        the output MLP.

        Args:
            x: node embeddings, e.g. the result of ``forward``; assumed to
                be (num_nodes, hidden) — TODO confirm.
            edge_index: index tensor whose rows 0 and 1 hold the endpoint
                node indices of the edges to score; assumed (2, num_edges).

        Returns:
            The output MLP's result for every edge in ``edge_index``.
        """
        edge_repr = x[edge_index[0]] * x[edge_index[1]]
        return self.output(edge_repr)
|