from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

from federatedscope.gfl.model.gcn import GCN_Net
from federatedscope.gfl.model.sage import SAGE_Net
from federatedscope.gfl.model.gat import GAT_Net
from federatedscope.gfl.model.gin import GIN_Net
from federatedscope.gfl.model.gpr import GPR_Net
from federatedscope.gfl.model.link_level import GNN_Net_Link
from federatedscope.gfl.model.graph_level import GNN_Net_Graph
from federatedscope.gfl.model.mpnn import MPNNs2s


def get_gnn(model_config, input_shape):
    """Instantiate the GNN selected by ``model_config``.

    Args:
        model_config: config object. This function reads its ``task``,
            ``type``, ``out_channels``, ``hidden``, ``layer`` and
            ``dropout`` attributes, plus ``graph_pooling`` for
            graph-level, non-MPNN models.
        input_shape: 3-item sequence ``(x_shape, num_label,
            num_edge_features)``.
            ``x_shape[-1]`` is passed as the input channel count of every
            model. ``num_label`` only affects graph-level non-MPNN models,
            where the output size is ``max(out_channels, num_label)``.
            ``num_edge_features`` is only forwarded to ``MPNNs2s`` as
            ``num_nn``.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if ``model_config.task`` does not start with
            ``'node'``, ``'link'`` or ``'graph'``, or if a node-level task
            names a ``model_config.type`` other than gcn/sage/gat/gin/gpr.
            For link- and graph-level tasks the type string is handed to
            the wrapper model unchecked.
    """
    x_shape, num_label, num_edge_features = input_shape
    # A missing/falsy label count becomes 0 so the max() below is valid.
    num_label = num_label or 0
    task = model_config.task

    if task.startswith('node'):
        # Node-level backbones that are all built with the same
        # (in, out, hidden=, max_depth=, dropout=) call.
        depth_models = {
            'gcn': GCN_Net,
            'sage': SAGE_Net,
            'gat': GAT_Net,
            'gin': GIN_Net,
        }
        gnn_type = model_config.type
        if gnn_type == 'gpr':
            # GPR_Net receives the layer count via ``K`` instead of
            # ``max_depth``.
            return GPR_Net(x_shape[-1],
                           model_config.out_channels,
                           hidden=model_config.hidden,
                           K=model_config.layer,
                           dropout=model_config.dropout)
        net_cls = depth_models.get(gnn_type)
        if net_cls is None:
            raise ValueError('not recognized gnn model {}'.format(gnn_type))
        return net_cls(x_shape[-1],
                       model_config.out_channels,
                       hidden=model_config.hidden,
                       max_depth=model_config.layer,
                       dropout=model_config.dropout)

    if task.startswith('link'):
        return GNN_Net_Link(x_shape[-1],
                            model_config.out_channels,
                            hidden=model_config.hidden,
                            max_depth=model_config.layer,
                            dropout=model_config.dropout,
                            gnn=model_config.type)

    if task.startswith('graph'):
        if model_config.type == 'mpnn':
            return MPNNs2s(in_channels=x_shape[-1],
                           out_channels=model_config.out_channels,
                           num_nn=num_edge_features,
                           hidden=model_config.hidden)
        return GNN_Net_Graph(x_shape[-1],
                             max(model_config.out_channels, num_label),
                             hidden=model_config.hidden,
                             max_depth=model_config.layer,
                             dropout=model_config.dropout,
                             gnn=model_config.type,
                             pooling=model_config.graph_pooling)

    raise ValueError('not recognized data task {}'.format(task))