222 lines
8.8 KiB
Python
222 lines
8.8 KiB
Python
import logging
|
|
import numpy as np
|
|
import federatedscope.register as register
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Modifications:
# 1. Added a `my_gcn` demo branch in `get_shape_from_data`
#    (torch DataLoader case)
# (2024-9-1, czzhangheng)
|
|
|
|
try:
|
|
from federatedscope.contrib.model import *
|
|
except ImportError as error:
|
|
logger.warning(
|
|
f'{error} in `federatedscope.contrib.model`, some modules are not '
|
|
f'available.')
|
|
|
|
|
|
def get_shape_from_data(data, model_config, backend='torch'):
    """
    Extract the input shape from the given data, which can be used to build \
    the model. Users can also use `data.input_shape` to specify the shape.

    Arguments:
        data (`ClientData`): the data used for local training or evaluation
        model_config: ``cfg.model``, a submodule of ``cfg``
        backend (str): chosen from ``torch`` and ``tensorflow``

    Returns:
        shape (tuple): the input shape. GNN-style models return a \
        ``(shape, num_label, num_edge_features)`` triple instead, \
        ``vmfnet``/``hmfnet`` return a column/row count, and \
        ``atc_model`` returns ``None``.

    Raises:
        TypeError: if a non-dict, non-DataLoader ``data_representative`` \
        cannot be unpacked as an ``(x, y)`` pair under the torch backend.
    """
    model_type = model_config.type.lower()

    # Handle some special cases whose "shape" is model-specific
    if model_type in ['vmfnet', 'hmfnet']:
        # Matrix-factorization models only need the column/row count
        return data['train'].n_col if model_type == 'vmfnet' \
            else data['train'].n_row
    elif model_type in ['gcn', 'sage', 'gpr', 'gat', 'gin', 'mpnn'] \
            or model_config.type.startswith('gnn_'):
        num_label = data['num_label'] if 'num_label' in data else None
        num_edge_features = data['data'][
            'num_edge_features'] if model_config.type == 'mpnn' else None
        if model_config.task.startswith('graph'):
            # graph-level task: one representative graph from the train set
            data_representative = next(iter(data['train']))
            return data_representative.x.shape, num_label, num_edge_features
        else:
            # node/link-level task: the full graph is stored under 'data'
            return data['data'].x.shape, num_label, num_edge_features
    elif model_type in ['atc_model']:
        # ATC builds its model from the config alone; no shape needed
        return None

    if isinstance(data, dict):
        # Pick a representative split; evaluation splits are preferred
        keys = list(data.keys())
        if 'test' in keys:
            key_representative = 'test'
        elif 'val' in keys:
            key_representative = 'val'
        elif 'train' in keys:
            key_representative = 'train'
        elif 'data' in keys:
            key_representative = 'data'
        else:
            key_representative = keys[0]
            logger.warning(f'We chose the key {key_representative} as the '
                           f'representative key to extract data shape.')

        data_representative = data[key_representative]
    else:
        # Handle the data with non-dict format
        data_representative = data

    if isinstance(data_representative, dict):
        # NOTE(review): a dict without an 'x' entry silently yields ``None``;
        # downstream `get_model` warns when the shape is ``None``.
        if 'x' in data_representative:
            shape = np.asarray(data_representative['x']).shape
            if len(shape) == 1:  # (batch, ) = (batch, 1)
                return 1
            else:
                return shape
    elif backend == 'torch':
        import torch
        if issubclass(type(data_representative), torch.utils.data.DataLoader):
            # Peek one batch to determine the feature shape
            x = next(iter(data_representative))
            if model_config.type == 'my_gcn':
                # my_gcn batches are graph objects carrying node features
                # in the `.x` attribute
                return x.x.shape
            if isinstance(x, list):
                return x[0].shape
            return x.shape
        else:
            try:
                x, _ = data_representative
                if isinstance(x, list):
                    return x[0].shape
                return x.shape
            except Exception as e:
                # Narrowed from a bare ``except`` so KeyboardInterrupt /
                # SystemExit are no longer swallowed; chain the cause.
                raise TypeError('Unsupported data type.') from e
    elif backend == 'tensorflow':
        # TODO: Handle more tensorflow type here
        shape = data_representative['x'].shape
        if len(shape) == 1:  # (batch, ) = (batch, 1)
            return 1
        else:
            return shape
|
|
|
|
|
|
def get_model(model_config, local_data=None, backend='torch'):
    """
    This function builds an instance of model to be trained.

    Arguments:
        model_config: ``cfg.model``, a submodule of ``cfg``
        local_data: the model to be instantiated is responsible for the \
        given data
        backend: chosen from ``torch`` and ``tensorflow``
    Returns:
        model (``torch.Module``): the instantiated model.
    Raises:
        ValueError: if ``model_config.type`` (or, for ``lr``, ``backend``) \
        is not supported.

    Note:
      The key-value pairs of built-in model and source are shown below:
        =================================== ==============================
        Model type                          Source
        =================================== ==============================
        ``lr``                              ``core.lr.LogisticRegression`` \
        or ``cross_backends.LogisticRegression``
        ``mlp``                             ``core.mlp.MLP``
        ``quadratic``                       ``tabular.model.QuadraticModel``
        ``convnet2, convnet5, vgg11``       ``cv.model.get_cnn()``
        ``lstm``                            ``nlp.model.get_rnn()``
        ``{}@transformers``                 ``nlp.model.get_transformer()``
        ``gcn, sage, gpr, gat, gin, mpnn``  ``gfl.model.get_gnn()``
        ``vmfnet, hmfnet``                  \
        ``mf.model.model_builder.get_mfnet()``
        =================================== ==============================
    """
    # Tree-based models do not need an input shape at build time
    if model_config.type.lower() in ['xgb_tree', 'gbdt_tree', 'random_forest']:
        input_shape = None
    elif local_data is not None:
        input_shape = get_shape_from_data(local_data, model_config, backend)
    else:
        input_shape = model_config.input_shape

    if input_shape is None:
        logger.warning('The input shape is None. Please specify the '
                       '`data.input_shape`(a tuple) or give the '
                       'representative data to `get_model` if necessary')

    # Registered (user-contributed) builders take precedence over built-ins;
    # the first one that returns a non-None model wins.
    for func in register.model_dict.values():
        model = func(model_config, input_shape)
        if model is not None:
            return model

    model_type = model_config.type.lower()
    if model_type == 'lr':
        if backend == 'torch':
            from federatedscope.core.lr import LogisticRegression
            model = LogisticRegression(in_channels=input_shape[-1],
                                       class_num=model_config.out_channels)
        elif backend == 'tensorflow':
            from federatedscope.cross_backends import LogisticRegression
            model = LogisticRegression(in_channels=input_shape[-1],
                                       class_num=1,
                                       use_bias=model_config.use_bias)
        else:
            # Was a bare ``raise ValueError``; the message makes the failure
            # diagnosable while callers catching ValueError are unaffected.
            raise ValueError(f'Unsupported backend {backend} for model lr')

    elif model_type == 'mlp':
        from federatedscope.core.mlp import MLP
        model = MLP(channel_list=[input_shape[-1]] + [model_config.hidden] *
                    (model_config.layer - 1) + [model_config.out_channels],
                    dropout=model_config.dropout)

    elif model_type == 'quadratic':
        from federatedscope.tabular.model import QuadraticModel
        model = QuadraticModel(input_shape[-1], 1)

    elif model_type in ['convnet2', 'convnet5', 'vgg11']:
        from federatedscope.cv.model import get_cnn
        model = get_cnn(model_config, input_shape)
    elif model_type in [
            'simclr', 'simclr_linear', "supervised_local", "supervised_fedavg"
    ]:
        from federatedscope.cl.model import get_simclr
        model = get_simclr(model_config, input_shape)
        if model_type.endswith('linear'):
            # Linear evaluation: freeze everything except the linear head
            for name, value in model.named_parameters():
                if not name.startswith('linear'):
                    value.requires_grad = False
    elif model_type in ['lstm']:
        from federatedscope.nlp.model import get_rnn
        model = get_rnn(model_config, input_shape)
    elif model_type.endswith('transformers'):
        from federatedscope.nlp.model import get_transformer
        model = get_transformer(model_config, input_shape)
    elif model_type in ['gcn', 'sage', 'gpr', 'gat', 'gin', 'mpnn']:
        from federatedscope.gfl.model import get_gnn
        model = get_gnn(model_config, input_shape)
    elif model_type in ['vmfnet', 'hmfnet']:
        from federatedscope.mf.model.model_builder import get_mfnet
        model = get_mfnet(model_config, input_shape)
    elif model_type in ['xgb_tree', 'gbdt_tree', 'random_forest']:
        from federatedscope.vertical_fl.tree_based_models.model.model_builder \
            import get_tree_model
        model = get_tree_model(model_config)
    elif model_type in ['atc_model']:
        from federatedscope.nlp.hetero_tasks.model import ATCModel
        model = ATCModel(model_config)
    elif model_type in ['feddgcn']:
        from federatedscope.trafficflow.model.FedDGCN import FedDGCN
        model = FedDGCN(model_config)
    else:
        raise ValueError('Model {} is not provided'.format(model_config.type))

    return model
|
|
|
|
|
|
def get_trainable_para_names(model):
    """Collect the names of all trainable parameters of ``model``.

    Arguments:
        model: an object exposing ``named_parameters()`` (e.g. a \
        ``torch.nn.Module``)

    Returns:
        set: names of the parameters whose ``requires_grad`` flag is set
    """
    return {
        name
        for name, param in model.named_parameters() if param.requires_grad
    }
|