Local Learning Abstraction: Trainer

FederatedScope decouples the local learning process from the details of FL communication and scheduling, allowing users to freely customize the local learning algorithm via the trainer. Each worker holds a trainer object to manage the details of local learning, such as the loss function, optimizer, training steps, and evaluation.

This tutorial is a shorter version of the full tutorial, where you can learn more details about the FS Trainer.

Code Structure

The code structure is shown below; we will discuss each of these concepts of our FS Trainer in the sections that follow.

federatedscope/core
├── trainers
│   ├── BaseTrainer
│   │   ├── Trainer
│   │   │   ├── GeneralTorchTrainer
│   │   │   ├── GeneralTFTrainer
│   │   │   ├── Context
│   │   │   ├── ...
│   │   ├── UserDefineTrainer
│   │   ├── ...

FS Trainer

A typical machine-learning process consists of the following procedures:

  1. Preparing the data.
  2. Iterating over the training dataset to update the model parameters.
  3. Evaluating the quality of the learned model on the validation/evaluation dataset.
  4. Saving, loading, and monitoring the model and intermediate results.

BaseTrainer

BaseTrainer is the abstract base class of our Trainer, which defines the interface of each method. You can implement your own trainer by inheriting from BaseTrainer; more examples can be found in federatedscope/contrib/trainer.

import abc


class BaseTrainer(abc.ABC):
    def __init__(self, model, data, device, **kwargs):
        self.model = model
        self.data = data
        self.device = device
        self.kwargs = kwargs

    @abc.abstractmethod
    def train(self):
        raise NotImplementedError

    @abc.abstractmethod
    def evaluate(self, target_data_split_name='test'):
        raise NotImplementedError

    @abc.abstractmethod
    def update(self, model_parameters, strict=False):
        raise NotImplementedError

    @abc.abstractmethod
    def get_model_para(self):
        raise NotImplementedError
    
    ... ...
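
For instance, a minimal custom trainer can be built on top of BaseTrainer as sketched below. This is a hedged illustration rather than code from federatedscope/contrib/trainer: the MyTrainer name, the plain PyTorch loop, and the assumption that self.data is a dict of DataLoaders keyed by split name are all our own.

import torch


class MyTrainer(BaseTrainer):
    def __init__(self, model, data, device, **kwargs):
        super().__init__(model, data, device, **kwargs)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)

    def train(self):
        # One local training pass over the train split.
        self.model.to(self.device)
        self.model.train()
        num_samples, total_loss = 0, 0.0
        for x, y in self.data['train']:
            x, y = x.to(self.device), y.to(self.device)
            self.optimizer.zero_grad()
            loss = self.criterion(self.model(x), y)
            loss.backward()
            self.optimizer.step()
            num_samples += y.size(0)
            total_loss += loss.item() * y.size(0)
        # (num_samples, model parameters, results) is the convention
        # FederatedScope workers expect from train().
        return num_samples, self.get_model_para(), \
            {'avg_loss': total_loss / num_samples}

    def evaluate(self, target_data_split_name='test'):
        self.model.to(self.device)
        self.model.eval()
        correct, num_samples = 0, 0
        with torch.no_grad():
            for x, y in self.data[target_data_split_name]:
                x, y = x.to(self.device), y.to(self.device)
                correct += (self.model(x).argmax(dim=-1) == y).sum().item()
                num_samples += y.size(0)
        return {f'{target_data_split_name}_acc': correct / num_samples}

    def update(self, model_parameters, strict=False):
        self.model.load_state_dict(model_parameters, strict=strict)

    def get_model_para(self):
        return {k: v.cpu() for k, v in self.model.state_dict().items()}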

Trainer

In the FederatedScope Trainer (a subclass of BaseTrainer), the procedures above are abstracted into high-level routines, built from a Context class and several pluggable hooks. We provide GeneralTorchTrainer and GeneralTFTrainer for PyTorch and TensorFlow, respectively.

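A rough usage sketch is given below. Hedged assumptions: the import path follows the code structure above, the constructor arguments follow the common model/data/device/config/monitor pattern, and train() returns the (num_samples, model_parameters, results) triple that workers consume; check the source for the exact signatures.

from federatedscope.core.trainers import GeneralTorchTrainer

# model, data, device, cfg, and monitor are built elsewhere by
# FederatedScope's setup routines.
trainer = GeneralTorchTrainer(model=model, data=data, device=device,
                              config=cfg, monitor=monitor)
num_samples, model_para, results = trainer.train()
metrics = trainer.evaluate(target_data_split_name='test')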

Context

The Context class (a subclass of dict) holds learning-related attributes, including the data, model, optimizer, etc., and users can add or delete these attributes in hook functions (a minimal usage sketch follows the list). We classify the default attributes and show them below:

  • Data-related attributes
    • ctx.data: the raw data (not split) the trainer holds
    • ctx.num_samples: the number of samples used in training
    • ctx.train_data, ctx.val_data, ctx.test_data: the split data the trainer holds
    • ctx.train_loader, ctx.val_loader, ctx.test_loader: the DataLoader of each split data
    • ctx.num_train_data, ctx.num_val_data, ctx.num_test_data: the number of samples of the split data
  • Model-related attributes
    • ctx.model: the model used
    • ctx.models: the multiple models, if used (e.g., in multi-model trainers)
    • ctx.mirrored_models: the mirrored models
    • ctx.trainable_para_names: the trainable parameter names of the model
  • Optimizer-related attributes
    • ctx.optimizer: see torch.optim for details
    • ctx.scheduler: decays the learning rate of each parameter group
    • ctx.criterion: loss/criterion function
    • ctx.regularizer: the regularization term
    • ctx.grad_clip: gradient clipping
  • Mode-related attributes
    • ctx.cur_mode: the current mode of the trainer, which is one of ['train', 'val', 'test', 'finetune']
    • ctx.mode_stack: the stack of modes, only used for switching modes
    • ctx.cur_split: the current data split, which is one of ['train', 'val', 'test'] (Note: using the train data in test mode is allowed)
    • ctx.split_stack: the stack of splits, only used for switching data splits
  • Metric-related attributes
    • ctx.loss_batch_total: the accumulated batch loss
    • ctx.loss_regular_total: the accumulated loss of the regularization term
    • ctx.y_true: the true labels of the current batch
    • ctx.y_prob: the model outputs for the current batch
    • ctx.ys_true: the true labels accumulated over all batches
    • ctx.ys_prob: the model outputs accumulated over all batches
    • ctx.eval_metrics: the evaluation metrics calculated by the Monitor
    • ctx.monitor: used to monitor the trainer's behavior and statistics
  • Other (statistics) attributes (@property, queried from cfg if not set)
    • ctx.cfg: the configuration of the FL course
    • ctx.device: the current device, such as cpu or cuda:0
    • ctx.num_train_batch_last_epoch, ctx.num_total_train_batch: the number of training batches in the last epoch, and the total number of training batches
    • ctx.num_train_epoch, ctx.num_val_epoch, ctx.num_test_epoch: the number of epochs for each data split
    • ctx.num_train_batch, ctx.num_val_batch, ctx.num_test_batch: the number of batches for each data split
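
A hook function receives ctx and reads or writes these attributes directly. As a minimal usage sketch (it mirrors what _hook_on_fit_start_init is described to do in the hook tables below, rather than quoting the source):

def _hook_on_fit_start_init(ctx):
    # Move the model to the current device and reset the accumulators
    # that later hooks fill in batch by batch.
    ctx.model.to(ctx.device)
    ctx.loss_batch_total = 0.
    ctx.loss_regular_total = 0.
    ctx.num_samples = 0
    ctx.ys_true = []
    ctx.ys_prob = []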

Hooks

Hooks represent fine-grained learning behaviors at different points in time, providing a simple yet powerful way to customize learning behaviors with a few modifications while reusing the rich set of default hooks. In this section, we show the details of each hook used in Trainer.

Hook trigger

A hook trigger is the point at which the registered hook functions are executed; all hook functions are executed following the nested pattern below (a registration sketch follows the list):

  • on_fit_start
    • on_epoch_start
      • on_batch_start
      • on_batch_forward
      • on_batch_backward
      • on_batch_end
    • on_epoch_end
  • on_fit_end
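
To customize behavior at a trigger, register a new hook function on it. The sketch below is hedged: register_hook_in_train follows the interface of FederatedScope's Trainer (see federatedscope/core/trainers/trainer.py for the exact signature), and the hook body itself is purely illustrative.

def _hook_on_batch_start_log(ctx):
    # Illustrative hook: report which data split the current batch comes from.
    print(f'Batch starts on split: {ctx.cur_split}')

trainer.register_hook_in_train(new_hook=_hook_on_batch_start_log,
                               trigger='on_batch_start')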
Train hooks

Train hooks are executed when ctx.cur_mode is train, following the execution paradigm as shown below:

  • on_fit_start
    • _hook_on_fit_start_init
    • _hook_on_fit_start_calculate_model_size
    • on_epoch_start
      • _hook_on_epoch_start
      • on_batch_start
        • _hook_on_batch_start_init
      • on_batch_forward
        • _hook_on_batch_forward
        • _hook_on_batch_forward_regularizer
        • _hook_on_batch_forward_flop_count
      • on_batch_backward
        • _hook_on_batch_backward
      • on_batch_end
        • _hook_on_batch_end
    • on_epoch_end
      • None
  • on_fit_end
    • _hook_on_fit_end

Evaluation (val/test) hooks

Evaluation hooks are executed when ctx.cur_mode is val or test, following the execution paradigm as shown below:

  • on_fit_start
    • _hook_on_fit_start_init
    • on_epoch_start
      • _hook_on_epoch_start
      • on_batch_start
        • _hook_on_batch_start_init
      • on_batch_forward
        • _hook_on_batch_forward
      • on_batch_backward
        • None
      • on_batch_end
        • _hook_on_batch_end
    • on_epoch_end
      • None
  • on_fit_end
    • _hook_on_fit_end

Finetune hooks

Finetune hooks are executed when ctx.cur_mode is finetune, following the execution paradigm as shown below:

  • on_fit_start
    • _hook_on_fit_start_init
    • _hook_on_fit_start_calculate_model_size
    • on_epoch_start
      • _hook_on_epoch_start
      • on_batch_start
        • _hook_on_batch_start_init
      • on_batch_forward
        • _hook_on_batch_forward
        • _hook_on_batch_forward_regularizer
        • _hook_on_batch_forward_flop_count
      • on_batch_backward
        • _hook_on_batch_backward
      • on_batch_end
        • _hook_on_batch_end
    • on_epoch_end
      • None
  • on_fit_end
    • _hook_on_fit_end

Hook functions

In this section, we will briefly describe what the hook functions do with the attributes/variables in ctx.

GeneralTorchTrainer

  • _hook_on_fit_start_init
    • ctx.model: Move to ctx.device
    • ctx.optimizer: Initialize by ctx.cfg
    • ctx.scheduler: Initialize by ctx.cfg
    • ctx.loss_batch_total: Initialize to 0
    • ctx.loss_regular_total: Initialize to 0
    • ctx.num_samples: Initialize to 0
    • ctx.ys_true: Initialize to []
    • ctx.ys_prob: Initialize to []
  • _hook_on_fit_start_calculate_model_size
    • ctx.monitor: Track the model size
  • _hook_on_epoch_start
    • ctx.{ctx.cur_split}_loader: Initialize the DataLoader
  • _hook_on_batch_start_init
    • ctx.data_batch: Initialize the batch data
  • _hook_on_batch_forward
    • ctx.y_true: Move to ctx.device
    • ctx.y_prob: Forward propagation to get y_prob
    • ctx.loss_batch: Calculate the loss
    • ctx.batch_size: Get the batch size
  • _hook_on_batch_forward_regularizer
    • ctx.loss_regular: Calculate the regularization loss
    • ctx.loss_task: Sum of ctx.loss_batch and ctx.loss_regular
  • _hook_on_batch_forward_flop_count
    • ctx.monitor: Track the average FLOPs
  • _hook_on_batch_backward
    • ctx.optimizer: Update by gradient
    • ctx.loss_task: Backward propagation
    • ctx.scheduler: Update the learning rate
  • _hook_on_batch_end
    • ctx.num_samples: Add ctx.batch_size
    • ctx.loss_batch_total: Add the batch loss
    • ctx.loss_regular_total: Add the batch regularization loss
    • ctx.ys_true: Append ctx.y_true
    • ctx.ys_prob: Append ctx.y_prob
  • _hook_on_fit_end
    • ctx.ys_true: Convert to numpy.array
    • ctx.ys_prob: Convert to numpy.array
    • ctx.monitor: Evaluate the results
    • ctx.eval_metrics: Get the evaluated results from ctx.monitor
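
As a concrete illustration, a hedged sketch of _hook_on_batch_forward is shown below; it follows the table above rather than quoting the source, so treat the unpacking of ctx.data_batch into an (input, label) pair as an assumption.

def _hook_on_batch_forward(ctx):
    # Move the batch to the device, run the forward pass, and record
    # the quantities that later hooks accumulate.
    x, label = [_.to(ctx.device) for _ in ctx.data_batch]
    pred = ctx.model(x)
    ctx.y_true = label
    ctx.y_prob = pred
    ctx.loss_batch = ctx.criterion(pred, label)
    ctx.batch_size = len(label)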
DittoTrainer

  • _hook_on_fit_start_set_regularized_para
    • ctx.global_model: Move to ctx.device and set to train mode
    • ctx.local_model: Move to ctx.device and set to train mode
    • ctx.optimizer_for_global_model: Initialize by ctx.cfg and wrap with wrap_regularized_optimizer
    • ctx.optimizer_for_local_model: Initialize by ctx.cfg and set the compared parameter group
  • _hook_on_fit_start_clean
    • ctx.optimizer: Delete
    • ctx.num_samples_local_model_train: Initialize to 0
  • _hook_on_fit_start_switch_local_model
    • ctx.model: Set to ctx.local_model and set to eval mode
  • _hook_on_batch_start_switch_model
    • ctx.use_local_model_current: Set to True or False
    • ctx.model: Set to ctx.local_model or ctx.global_model
    • ctx.optimizer: Set to ctx.optimizer_for_local_model or ctx.optimizer_for_global_model
  • _hook_on_batch_forward_cnt_num
    • ctx.num_samples_local_model_train: Add ctx.batch_size
  • _hook_on_batch_end_flop_count
    • ctx.monitor: Track the total FLOPs
  • _hook_on_fit_end_calibrate
    • ctx.num_samples: Subtract ctx.num_samples_local_model_train
    • ctx.eval_metrics: Record train_total and train_total_local_model
  • _hook_on_fit_end_switch_global_model
    • ctx.model: Set to ctx.global_model
  • _hook_on_fit_end_free_cuda
    • ctx.global_model: Move to CPU
    • ctx.local_model: Move to CPU
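
A hedged sketch of the per-batch switch in _hook_on_batch_start_switch_model follows; ctx.cur_batch_i and ctx.num_train_batch_for_local_model are hypothetical names standing in for however the trainer decides how many batches update the personalized model.

def _hook_on_batch_start_switch_model(ctx):
    # Hypothetical rule: the first batches update the personalized local
    # model, the remaining ones update the shared global model.
    ctx.use_local_model_current = \
        ctx.cur_batch_i < ctx.num_train_batch_for_local_model
    if ctx.use_local_model_current:
        ctx.model = ctx.local_model
        ctx.optimizer = ctx.optimizer_for_local_model
    else:
        ctx.model = ctx.global_model
        ctx.optimizer = ctx.optimizer_for_global_model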
pFedMeTrainer

  • _hook_on_fit_start_set_local_para_tmp
    • ctx.optimizer: Wrap with wrap_regularized_optimizer and set the compared parameter group
    • ctx.pFedMe_outer_lr: Initialize to ctx.cfg.train.optimizer.lr
    • ctx.pFedMe_local_model_tmp: Copy from ctx.model
  • _hook_on_batch_start_init_pfedme
    • ctx.data_batch_cache: Copy from ctx.data_batch
    • ctx.pFedMe_approx_fit_counter: Count so as to refresh the data every K steps
  • _hook_on_batch_end_flop_count
    • ctx.monitor: Track the total FLOPs
  • _hook_on_epoch_end_flop_count
    • ctx.monitor: Track the total FLOPs
  • _hook_on_epoch_end_update_local
    • ctx.model: Update parameters by ctx.pFedMe_local_model_tmp
    • ctx.optimizer: Set the compared parameter group
  • _hook_on_fit_end_update_local
    • ctx.model: Update parameters by ctx.pFedMe_local_model_tmp
    • ctx.pFedMe_local_model_tmp: Delete
FedProxTrainer & NbaflTrainer

  • _hook_record_initialization
    • ctx.weight_init: Copy from ctx.model
  • _hook_del_initialization
    • ctx.weight_init: Set to None
  • _hook_inject_noise_in_upload
    • ctx.model: Inject noise into the parameters
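
For instance, here is a hedged sketch of the noise injection in the upload hook (Gaussian perturbation in the NbAFL style; in practice the noise scale is derived from the privacy budget in ctx.cfg, which we leave abstract via the illustrative noise_std argument):

import torch


def _hook_inject_noise_in_upload(ctx, noise_std=0.01):  # noise_std is illustrative
    # Perturb each parameter with zero-mean Gaussian noise before the
    # model is uploaded to the server.
    with torch.no_grad():
        for p in ctx.model.parameters():
            p.add_(torch.randn_like(p) * noise_std)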