267 lines
12 KiB
Python
267 lines
12 KiB
Python
import logging
|
|
import importlib
|
|
|
|
import federatedscope.register as register
|
|
from federatedscope.core.trainers import Trainer
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Best-effort import of the contributed trainers: a failing import must not
# break the built-in trainers, so it is only reported as a warning.
# NOTE(review): presumably importing the package lets contributed trainers
# register themselves in ``register.trainer_dict`` -- confirm.
try:
    from federatedscope.contrib.trainer import *
except ImportError as error:
    logger.warning(
        f'{error} in `federatedscope.contrib.trainer`, some modules are not '
        f'available.')
|
|
|
|
# Maps a lower-cased ``cfg.trainer.type`` value to the name of the trainer
# class that ``get_trainer`` loads dynamically for it.
TRAINER_CLASS_DICT = {
    "cvtrainer": "CVTrainer",
    "nlptrainer": "NLPTrainer",
    "graphminibatch_trainer": "GraphMiniBatchTrainer",
    "linkfullbatch_trainer": "LinkFullBatchTrainer",
    "linkminibatch_trainer": "LinkMiniBatchTrainer",
    "nodefullbatch_trainer": "NodeFullBatchTrainer",
    "nodeminibatch_trainer": "NodeMiniBatchTrainer",
    "flitplustrainer": "FLITPlusTrainer",
    "flittrainer": "FLITTrainer",
    "fedvattrainer": "FedVATTrainer",
    "fedfocaltrainer": "FedFocalTrainer",
    "mftrainer": "MFTrainer",
    "cltrainer": "CLTrainer",
    "lptrainer": "LPTrainer",
    "atc_trainer": "ATCTrainer",
}
|
|
|
|
|
|
# Module that defines the class of each built-in trainer type; the class
# name itself is looked up in ``TRAINER_CLASS_DICT``.
_TRAINER_MODULE_DICT = {
    "cvtrainer": "federatedscope.cv.trainer.trainer",
    "nlptrainer": "federatedscope.nlp.trainer.trainer",
    "cltrainer": "federatedscope.cl.trainer.trainer",
    "lptrainer": "federatedscope.cl.trainer.trainer",
    "graphminibatch_trainer": "federatedscope.gfl.trainer.graphtrainer",
    "linkfullbatch_trainer": "federatedscope.gfl.trainer.linktrainer",
    "linkminibatch_trainer": "federatedscope.gfl.trainer.linktrainer",
    "nodefullbatch_trainer": "federatedscope.gfl.trainer.nodetrainer",
    "nodeminibatch_trainer": "federatedscope.gfl.trainer.nodetrainer",
    "flitplustrainer": "federatedscope.gfl.flitplus.trainer",
    "flittrainer": "federatedscope.gfl.flitplus.trainer",
    "fedvattrainer": "federatedscope.gfl.flitplus.trainer",
    "fedfocaltrainer": "federatedscope.gfl.flitplus.trainer",
    "mftrainer": "federatedscope.mf.trainer.trainer",
    "atc_trainer": "federatedscope.nlp.hetero_tasks.trainer",
}


def get_trainer(model=None,
                data=None,
                device=None,
                config=None,
                only_for_eval=False,
                is_attacker=False,
                monitor=None):
    """
    This function builds an instance of trainer.

    The base trainer is selected by ``config.trainer.type`` and is then
    wrapped by the plug-ins enabled in ``config``.

    Arguments:
        model: model used in FL course
        data: data used in FL course
        device: where to train model (``cpu`` or ``gpu``)
        config: configurations for FL, see ``federatedscope.core.configs``
        only_for_eval: ``True`` or ``False``, if ``True``, ``train``
            routine will be removed in this trainer (not forwarded to the
            ``verticaltrainer`` and ``trafficflowtrainer`` builders)
        is_attacker: ``True`` or ``False`` to determine whether this client
            is an attacker
        monitor: an instance of ``federatedscope.core.monitors.Monitor`` to
            observe the evaluation and system metrics

    Returns:
        An instance of trainer, or ``None`` when ``config.trainer.type``
        is ``'none'``. A trainer that is not an instance of
        ``federatedscope.core.trainers.Trainer`` is returned without any
        wrapper applied (a warning is logged).

    Raises:
        ValueError: if ``config.backend`` is neither ``torch`` nor
            ``tensorflow`` for the ``general`` trainer, or if no built-in
            or registered trainer matches ``config.trainer.type``.

    Note:
      The key-value pairs of ``cfg.trainer.type`` and trainers
      (``general`` and ``none`` are matched case-sensitively, all other
      types case-insensitively):

        * ``general``: ``core.trainers.GeneralTorchTrainer`` and
          ``core.trainers.GeneralTFTrainer``
        * ``none``: no trainer is built, ``None`` is returned
        * ``cvtrainer``: ``cv.trainer.trainer.CVTrainer``
        * ``nlptrainer``: ``nlp.trainer.trainer.NLPTrainer``
        * ``cltrainer``: ``cl.trainer.trainer.CLTrainer``
        * ``lptrainer``: ``cl.trainer.trainer.LPTrainer``
        * ``graphminibatch_trainer``:
          ``gfl.trainer.graphtrainer.GraphMiniBatchTrainer``
        * ``linkfullbatch_trainer``:
          ``gfl.trainer.linktrainer.LinkFullBatchTrainer``
        * ``linkminibatch_trainer``:
          ``gfl.trainer.linktrainer.LinkMiniBatchTrainer``
        * ``nodefullbatch_trainer``:
          ``gfl.trainer.nodetrainer.NodeFullBatchTrainer``
        * ``nodeminibatch_trainer``:
          ``gfl.trainer.nodetrainer.NodeMiniBatchTrainer``
        * ``flitplustrainer``: ``gfl.flitplus.trainer.FLITPlusTrainer``
        * ``flittrainer``: ``gfl.flitplus.trainer.FLITTrainer``
        * ``fedvattrainer``: ``gfl.flitplus.trainer.FedVATTrainer``
        * ``fedfocaltrainer``: ``gfl.flitplus.trainer.FedFocalTrainer``
        * ``mftrainer``: ``mf.trainer.trainer.MFTrainer``
        * ``atc_trainer``: ``nlp.hetero_tasks.trainer.ATCTrainer``
        * ``verticaltrainer``:
          ``vertical_fl.tree_based_models.trainer.utils.get_vertical_trainer``
        * ``trafficflowtrainer``:
          ``trafficflow.trainer.trafficflow_trainer.call_trafficflow_trainer``
        * any other value: looked up in ``register.trainer_dict``, e.g.
          ``mytorchtrainer`` for
          ``contrib.trainer.torch_example.MyTorchTrainer``

      Wrapper functions are shown below:

        * ``nbafl``: ``core.trainers.wrap_nbafl_trainer``
        * ``sgdmf``: ``mf.trainer.wrap_MFTrainer``
        * ``pfedme``: ``core.trainers.wrap_pFedMeTrainer``
        * ``ditto``: ``core.trainers.wrap_DittoTrainer``
        * ``fedem``: ``core.trainers.FedEMTrainer``
        * ``fedrep``: ``core.trainers.wrap_FedRepTrainer``
        * ``fedprox``: ``core.trainers.wrap_fedprox_trainer``
        * ``attack``: ``attack.trainer.wrap_benignTrainer`` and
          ``attack.auxiliary.attack_trainer_builder.wrap_attacker_trainer``
        * simple tuning: ``core.trainers.wrap_Simple_tuning_Trainer``
    """
    trainer_type = config.trainer.type
    if trainer_type == 'none':
        return None

    # Every type except 'general'/'none' is matched case-insensitively.
    type_key = trainer_type.lower()

    # Keyword arguments shared by all class-based trainer constructors.
    trainer_kwargs = dict(model=model,
                          data=data,
                          device=device,
                          config=config,
                          only_for_eval=only_for_eval,
                          monitor=monitor)

    if trainer_type == 'general':
        # Imports stay local so only the selected backend gets imported.
        if config.backend == 'torch':
            from federatedscope.core.trainers import GeneralTorchTrainer
            trainer = GeneralTorchTrainer(**trainer_kwargs)
        elif config.backend == 'tensorflow':
            from federatedscope.core.trainers import GeneralTFTrainer
            trainer = GeneralTFTrainer(**trainer_kwargs)
        else:
            raise ValueError(
                f"Backend '{config.backend}' is not supported by the "
                f"'general' trainer, expected 'torch' or 'tensorflow'")
    elif type_key in TRAINER_CLASS_DICT:
        dict_path = _TRAINER_MODULE_DICT.get(type_key)
        if dict_path is None:
            # Only possible if the two lookup tables get out of sync.
            raise ValueError(
                f"No module path is known for trainer type '{type_key}'")
        trainer_cls = getattr(importlib.import_module(name=dict_path),
                              TRAINER_CLASS_DICT[type_key])
        trainer = trainer_cls(**trainer_kwargs)
    elif type_key == 'verticaltrainer':
        from federatedscope.vertical_fl.tree_based_models.trainer.utils \
            import get_vertical_trainer
        trainer = get_vertical_trainer(config=config,
                                       model=model,
                                       data=data,
                                       device=device,
                                       monitor=monitor)
    # Add traffic flow trainer
    elif type_key == 'trafficflowtrainer':
        from federatedscope.trafficflow.trainer.trafficflow_trainer \
            import call_trafficflow_trainer
        trainer = call_trafficflow_trainer(config=config,
                                           model=model,
                                           data=data,
                                           device=device,
                                           monitor=monitor)
    else:
        # try to find user registered trainer
        # Every registered factory is consulted and the last match wins,
        # but only that final match is instantiated.
        trainer_cls = None
        for func in register.trainer_dict.values():
            candidate = func(config.trainer.type)
            if candidate is not None:
                trainer_cls = candidate
        trainer = None
        if trainer_cls is not None:
            trainer = trainer_cls(**trainer_kwargs)
        if trainer is None:
            raise ValueError('Trainer {} is not provided'.format(
                config.trainer.type))

    if not isinstance(trainer, Trainer):
        # The wrappers below rely on the ``Trainer`` interface, so any other
        # object is handed back as it is.
        logger.warning(f'Hook-like plug-in functions cannot be enabled when '
                       f'using {trainer}. If you want use our wrapper '
                       f'functions for your trainer please consider '
                       f'inheriting from '
                       f'`federatedscope.core.trainers.Trainer` instead.')
        return trainer

    # differential privacy plug-in
    if config.nbafl.use:
        from federatedscope.core.trainers import wrap_nbafl_trainer
        trainer = wrap_nbafl_trainer(trainer)
    if config.sgdmf.use:
        from federatedscope.mf.trainer import wrap_MFTrainer
        trainer = wrap_MFTrainer(trainer)

    # personalization plug-in
    fed_method = config.federate.method.lower()
    if fed_method == "pfedme":
        from federatedscope.core.trainers import wrap_pFedMeTrainer
        # wrap style: instance a (class A) -> instance a (class A)
        trainer = wrap_pFedMeTrainer(trainer)
    elif fed_method == "ditto":
        from federatedscope.core.trainers import wrap_DittoTrainer
        # wrap style: instance a (class A) -> instance a (class A)
        trainer = wrap_DittoTrainer(trainer)
    elif fed_method == "fedem":
        from federatedscope.core.trainers import FedEMTrainer
        # copy construct style: instance a (class A) -> instance b (class B)
        trainer = FedEMTrainer(model_nums=config.model.model_num_per_trainer,
                               base_trainer=trainer)
    elif fed_method == "fedrep":
        from federatedscope.core.trainers import wrap_FedRepTrainer
        # wrap style: instance a (class A) -> instance a (class A)
        trainer = wrap_FedRepTrainer(trainer)

    # attacker plug-in
    is_backdoor = 'backdoor' in config.attack.attack_method
    if is_backdoor:
        from federatedscope.attack.trainer import wrap_benignTrainer
        trainer = wrap_benignTrainer(trainer)

    if is_attacker:
        if is_backdoor:
            logger.info('--------This client is a backdoor attacker --------')
        else:
            logger.info('-------- This client is an privacy attacker --------')
        from federatedscope.attack.auxiliary.attack_trainer_builder \
            import wrap_attacker_trainer
        trainer = wrap_attacker_trainer(trainer, config)
    elif is_backdoor:
        logger.info(
            '----- This client is a benign client for backdoor attacks -----')

    # fed algorithm plug-in
    if config.fedprox.use:
        from federatedscope.core.trainers import wrap_fedprox_trainer
        trainer = wrap_fedprox_trainer(trainer)

    # different fine-tuning
    if config.finetune.before_eval and config.finetune.simple_tuning:
        from federatedscope.core.trainers import wrap_Simple_tuning_Trainer
        trainer = wrap_Simple_tuning_Trainer(trainer)

    return trainer
|