# FS-TFP/federatedscope/gfl/model/model_builder.py
#
# 82 lines
# 3.7 KiB
# Python

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
from federatedscope.gfl.model.gcn import GCN_Net
from federatedscope.gfl.model.sage import SAGE_Net
from federatedscope.gfl.model.gat import GAT_Net
from federatedscope.gfl.model.gin import GIN_Net
from federatedscope.gfl.model.gpr import GPR_Net
from federatedscope.gfl.model.link_level import GNN_Net_Link
from federatedscope.gfl.model.graph_level import GNN_Net_Graph
from federatedscope.gfl.model.mpnn import MPNNs2s
def get_gnn(model_config, input_shape):
    """Instantiate the GNN selected by ``model_config`` for the given task.

    Args:
        model_config: Configuration object. ``task`` and ``type`` pick the
            model; ``out_channels``, ``hidden``, ``layer``, ``dropout`` and
            ``graph_pooling`` are forwarded to the chosen constructor as
            needed.
        input_shape: A 3-item iterable unpacked as
            ``(x_shape, num_label, num_edge_features)``. ``x_shape[-1]`` is
            handed to the model as its input size (presumably the node
            feature dimension -- confirm against the data pipeline). A
            falsy ``num_label`` is treated as 0.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If the task does not start with ``node``, ``link`` or
            ``graph``, or if a node-level task names an unknown model type.
    """
    x_shape, num_label, num_edge_features = input_shape
    num_label = num_label or 0
    task = model_config.task

    if task.startswith('node'):
        gnn_type = model_config.type
        # Every node-level model takes the same arguments; only the
        # keyword carrying the layer count differs (GPR calls it ``K``).
        depth_kwarg = 'max_depth'
        if gnn_type == 'gcn':
            net_cls = GCN_Net
        elif gnn_type == 'sage':
            net_cls = SAGE_Net
        elif gnn_type == 'gat':
            net_cls = GAT_Net
        elif gnn_type == 'gin':
            net_cls = GIN_Net
        elif gnn_type == 'gpr':
            net_cls = GPR_Net
            depth_kwarg = 'K'
        else:
            raise ValueError('not recognized gnn model {}'.format(gnn_type))
        in_dim = x_shape[-1]
        out_dim = model_config.out_channels
        net_kwargs = {
            'hidden': model_config.hidden,
            depth_kwarg: model_config.layer,
            'dropout': model_config.dropout,
        }
        return net_cls(in_dim, out_dim, **net_kwargs)

    if task.startswith('link'):
        # The link-level wrapper picks its own backbone from ``gnn``.
        return GNN_Net_Link(x_shape[-1],
                            model_config.out_channels,
                            hidden=model_config.hidden,
                            max_depth=model_config.layer,
                            dropout=model_config.dropout,
                            gnn=model_config.type)

    if task.startswith('graph'):
        if model_config.type == 'mpnn':
            # MPNN is the only model that consumes the edge-feature count.
            return MPNNs2s(in_channels=x_shape[-1],
                           out_channels=model_config.out_channels,
                           num_nn=num_edge_features,
                           hidden=model_config.hidden)
        in_dim = x_shape[-1]
        # Output width is the larger of the configured size and the
        # number of labels reported with the data.
        out_dim = max(model_config.out_channels, num_label)
        return GNN_Net_Graph(in_dim,
                             out_dim,
                             hidden=model_config.hidden,
                             max_depth=model_config.layer,
                             dropout=model_config.dropout,
                             gnn=model_config.type,
                             pooling=model_config.graph_pooling)

    raise ValueError('not recognized data task {}'.format(task))