Local Learning Abstraction: Trainer
FederatedScope decouples the local learning process from the details of FL communication and scheduling, allowing users to freely customize the local learning algorithm via the trainer. Each worker holds a trainer object that manages the details of local learning, such as the loss function, optimizer, training step, and evaluation.
This tutorial is a shorter version of the full tutorial, where you can learn more details about the FS Trainer.
Code Structure
The code structure is shown below, and we will discuss all the concepts of our FS Trainer later.
federatedscope/core
├── trainers
│ ├── BaseTrainer
│ │ ├── Trainer
│ │ │ ├── GeneralTorchTrainer
│ │ │ ├── GeneralTFTrainer
│ │ │ ├── Context
│ │ │ ├── ...
│ │ ├── UserDefineTrainer
│ │ ├── ...
FS Trainer
A typical machine-learning process consists of the following procedures:
- Preparing data
- Iterating over the training datasets to update the model parameters
- Evaluating the quality of the learned model on validation/evaluation datasets
- Saving, loading, and monitoring the model and intermediate results
BaseTrainer
BaseTrainer is the abstract base class of our Trainer and provides the interface of each method. You can implement your own trainer by inheriting from BaseTrainer. More examples can be found in federatedscope/contrib/trainer.
import abc


class BaseTrainer(abc.ABC):
    def __init__(self, model, data, device, **kwargs):
        self.model = model
        self.data = data
        self.device = device
        self.kwargs = kwargs

    @abc.abstractmethod
    def train(self):
        raise NotImplementedError

    @abc.abstractmethod
    def evaluate(self, target_data_split_name='test'):
        raise NotImplementedError

    @abc.abstractmethod
    def update(self, model_parameters, strict=False):
        raise NotImplementedError

    @abc.abstractmethod
    def get_model_para(self):
        raise NotImplementedError

... ...
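As a rough illustration of this interface, the sketch below implements the four abstract methods in a toy subclass. To stay self-contained it re-declares a minimal stand-in for BaseTrainer rather than importing the FederatedScope class. MeanTrainer, its dict-based "model", and the return formats are hypothetical and chosen only for this example.

```python
import abc


class BaseTrainer(abc.ABC):
    """Minimal stand-in mirroring the interface above (not the FS class)."""
    def __init__(self, model, data, device, **kwargs):
        self.model = model
        self.data = data
        self.device = device
        self.kwargs = kwargs

    @abc.abstractmethod
    def train(self):
        raise NotImplementedError

    @abc.abstractmethod
    def evaluate(self, target_data_split_name='test'):
        raise NotImplementedError

    @abc.abstractmethod
    def update(self, model_parameters, strict=False):
        raise NotImplementedError

    @abc.abstractmethod
    def get_model_para(self):
        raise NotImplementedError


class MeanTrainer(BaseTrainer):
    """Hypothetical trainer whose 'model' is a dict holding one scalar."""
    def train(self):
        # "Training" here just fits the mean of the local training split.
        samples = self.data['train']
        self.model['mean'] = sum(samples) / len(samples)
        # The return format is an assumption made for this sketch.
        return len(samples), self.get_model_para(), {}

    def evaluate(self, target_data_split_name='test'):
        # Mean squared error of the fitted mean on the requested split.
        samples = self.data[target_data_split_name]
        mse = sum((x - self.model['mean']) ** 2 for x in samples) / len(samples)
        return {f'{target_data_split_name}_mse': mse}

    def update(self, model_parameters, strict=False):
        # Overwrite local parameters with the received ones.
        if strict and set(model_parameters) != set(self.model):
            raise KeyError('parameter names do not match')
        self.model.update(model_parameters)

    def get_model_para(self):
        # Return a copy so callers cannot mutate the local model in place.
        return dict(self.model)


trainer = MeanTrainer(model={'mean': 0.0},
                      data={'train': [1.0, 2.0, 3.0], 'test': [2.0, 2.0]},
                      device='cpu')
num_samples, params, _ = trainer.train()
test_metrics = trainer.evaluate('test')
```

Because all four methods are abstract, a subclass that omits any of them cannot be instantiated, which is what makes BaseTrainer usable as an interface contract.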
Trainer
As the figure shows, the FederatedScope Trainer (a subclass of BaseTrainer) wraps the procedures above in high-level routine abstractions, which are built from a Context class and several pluggable Hooks. We provide GeneralTorchTrainer and GeneralTFTrainer for PyTorch and TensorFlow, respectively.
Context
The Context class (a subclass of dict) holds learning-related attributes, including the data, model, optimizer, etc., and users can add or delete these attributes in hook functions. We classify and show the default attributes below:
- Data-related attributes
  - ctx.data: the raw data (not split) the trainer holds
  - ctx.num_samples: the number of samples used in training
  - ctx.train_data, ctx.val_data, ctx.test_data: the split data the trainer holds
  - ctx.train_loader, ctx.val_loader, ctx.test_loader: the DataLoader of each split data
  - ctx.num_train_data, ctx.num_val_data, ctx.num_test_data: the number of samples of the split data
- Model-related attributes
  - ctx.model: the model used
  - ctx.models: the multiple models, if used
  - ctx.mirrored_models: the mirrored models
  - ctx.trainable_para_names: the trainable parameter names of the model
- Optimizer-related attributes
  - ctx.optimizer: see torch.optim for details
  - ctx.scheduler: decays the learning rate of each parameter group
  - ctx.criterion: loss/criterion function
  - ctx.regularizer: regularization terms
  - ctx.grad_clip: gradient clipping
- Mode-related attributes
  - ctx.cur_mode: mode of the trainer, which is one of ['train', 'val', 'test']
  - ctx.mode_stack: stack of modes, only used for switching mode
  - ctx.cur_split: split of data, which is one of ['train', 'val', 'test'] (Note: using train data in test mode is allowed)
  - ctx.split_stack: stack of splits, only used for switching data split
- Metric-related attributes
  - ctx.loss_batch_total: loss of current batch
  - ctx.loss_regular_total: loss of the regularization term
  - ctx.y_true: true label of batch data
  - ctx.y_prob: output of the model with batch data as input
  - ctx.ys_true: true label of data
  - ctx.ys_prob: output of the model
  - ctx.eval_metrics: evaluation metrics calculated by Monitor
  - ctx.monitor: used to monitor the trainer's behavior and statistics
- Other (statistics) attributes (@property, query from cfg if not set)
  - ctx.cfg: configuration of the FL course, see link for details
  - ctx.device: current device, such as cpu and gpu0
  - ctx.num_train_batch_last_epoch, ctx.num_total_train_batch: the number of batches
  - ctx.num_train_epoch, ctx.num_val_epoch, ctx.num_test_epoch: the number of epochs in each data split
  - ctx.num_train_batch, ctx.num_val_batch, ctx.num_test_batch: the number of batches in each data split
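To make the "subclass of dict" idea concrete, here is a minimal sketch of a dict whose keys are also reachable as attributes, so hook functions can add or delete entries with ctx.name syntax. This Context is a hypothetical stand-in written for illustration, not the FederatedScope class, and the two hook functions are invented for the example.

```python
class Context(dict):
    """Hypothetical stand-in: a dict whose keys double as attributes."""
    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        # Attribute assignment stores a dict item.
        self[name] = value

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None


def hook_add_counter(ctx):
    # A hook may attach a new attribute to the context.
    ctx.num_seen_batches = 0


def hook_remove_counter(ctx):
    # A later hook may delete it again.
    del ctx.num_seen_batches


ctx = Context(model='toy-model', cur_mode='train')
hook_add_counter(ctx)
has_counter = 'num_seen_batches' in ctx
hook_remove_counter(ctx)
```

Keeping everything in one dict-like object means hooks only need a single argument and can share state without knowing about each other.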
Hooks
Hooks represent fine-grained learning behaviors at different points in time. They provide a simple yet powerful way to customize learning behaviors with few modifications and make it easy to reuse the default hooks. In this section, we show the details of each hook used in Trainer.
Hook trigger
A hook trigger is the point at which hook functions are executed. All hook functions are executed following the pattern below:
- on_fit_start
  - on_epoch_start
    - on_batch_start
    - on_batch_forward
    - on_batch_backward
    - on_batch_end
  - on_epoch_end
- on_fit_end
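One way to read this pattern is as a pair of nested loops: the batch triggers fire inside each epoch, and the epoch triggers fire inside one fit. The sketch below assumes exactly that nesting. run_routine and the plain-dict ctx are hypothetical names used only for this illustration.

```python
from collections import defaultdict

TRIGGERS = [
    'on_fit_start', 'on_epoch_start', 'on_batch_start', 'on_batch_forward',
    'on_batch_backward', 'on_batch_end', 'on_epoch_end', 'on_fit_end',
]


def run_routine(hooks, ctx, num_epochs, num_batches):
    """Fire every registered hook in trigger order (assumed nesting)."""
    def fire(trigger):
        for hook in hooks[trigger]:
            hook(ctx)

    fire('on_fit_start')
    for _ in range(num_epochs):
        fire('on_epoch_start')
        for _ in range(num_batches):
            fire('on_batch_start')
            fire('on_batch_forward')
            fire('on_batch_backward')
            fire('on_batch_end')
        fire('on_epoch_end')
    fire('on_fit_end')


# Register one tracing hook per trigger to record the firing order.
hooks = defaultdict(list)
for name in TRIGGERS:
    hooks[name].append(lambda ctx, name=name: ctx['trace'].append(name))

ctx = {'trace': []}
run_routine(hooks, ctx, num_epochs=1, num_batches=2)
trace = ctx['trace']
```

With one epoch and two batches, the trace contains each fit and epoch trigger once and each batch trigger twice.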
Train hooks
Train hooks are executed when ctx.cur_mode is train, following the execution paradigm as shown below:
- on_fit_start: _hook_on_fit_start_init, _hook_on_fit_start_calculate_model_size
  - on_epoch_start: _hook_on_epoch_start
    - on_batch_start: _hook_on_batch_start_init
    - on_batch_forward: _hook_on_batch_forward, _hook_on_batch_forward_regularizer, _hook_on_batch_forward_flop_count
    - on_batch_backward: _hook_on_batch_backward
    - on_batch_end: _hook_on_batch_end
  - on_epoch_end: None
- on_fit_end: _hook_on_fit_end
Evaluation (val/test) hooks
Evaluation hooks are executed when ctx.cur_mode is val or test, following the execution paradigm as shown below:
- on_fit_start: _hook_on_fit_start_init
  - on_epoch_start: _hook_on_epoch_start
    - on_batch_start: _hook_on_batch_start_init
    - on_batch_forward: _hook_on_batch_forward
    - on_batch_backward: None
    - on_batch_end: _hook_on_batch_end
  - on_epoch_end: None
- on_fit_end: _hook_on_fit_end
Finetune hooks
Finetune hooks are executed when ctx.cur_mode is finetune, following the execution paradigm as shown below:
- on_fit_start: _hook_on_fit_start_init, _hook_on_fit_start_calculate_model_size
  - on_epoch_start: _hook_on_epoch_start
    - on_batch_start: _hook_on_batch_start_init
    - on_batch_forward: _hook_on_batch_forward, _hook_on_batch_forward_regularizer, _hook_on_batch_forward_flop_count
    - on_batch_backward: _hook_on_batch_backward
    - on_batch_end: _hook_on_batch_end
  - on_epoch_end: None
- on_fit_end: _hook_on_fit_end
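The listings above differ mainly in which hooks are attached to each trigger: evaluation has no backward hook, while training does. The sketch below shows one way such per-mode hook tables could be kept, and how a user-defined hook might be plugged in ahead of a default one. HookRegistry, its register method, and _hook_record_loss are hypothetical and are not FederatedScope APIs. The three default hook names are reused from the listings as no-op stubs.

```python
from collections import defaultdict


# No-op stubs standing in for default hooks named in the listings above.
def _hook_on_batch_forward(ctx):
    pass


def _hook_on_batch_backward(ctx):
    pass


def _hook_on_batch_end(ctx):
    pass


class HookRegistry:
    """Hypothetical registry keeping one hook table per mode."""
    def __init__(self):
        self.tables = {'train': defaultdict(list), 'eval': defaultdict(list)}

    def register(self, mode, trigger, hook, insert_pos=None):
        hooks = self.tables[mode][trigger]
        if hook in hooks:
            return  # avoid registering the same hook twice
        if insert_pos is None:
            hooks.append(hook)
        else:
            hooks.insert(insert_pos, hook)

    def names(self, mode, trigger):
        return [hook.__name__ for hook in self.tables[mode][trigger]]


registry = HookRegistry()
for mode in ('train', 'eval'):
    registry.register(mode, 'on_batch_forward', _hook_on_batch_forward)
    registry.register(mode, 'on_batch_end', _hook_on_batch_end)
# Only training performs a backward pass.
registry.register('train', 'on_batch_backward', _hook_on_batch_backward)


def _hook_record_loss(ctx):
    """User-defined hook (hypothetical): remember every batch loss."""
    ctx.setdefault('seen_losses', []).append(ctx['loss_batch'])


# Plug the custom hook in before the default on_batch_end hook, train only.
registry.register('train', 'on_batch_end', _hook_record_loss, insert_pos=0)
```

Since the finetune listing matches the train listing, a finetune mode could plausibly reuse the train table in a design like this.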
Hook functions
In this section, we will briefly describe what the hook functions do with the attributes/variables in ctx.
GeneralTorchTrainer
- _hook_on_fit_start_init

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.model | Move to ctx.device |
  | ctx.optimizer | Initialize by ctx.cfg |
  | ctx.scheduler | Initialize by ctx.cfg |
  | ctx.loss_batch_total | Initialize to 0 |
  | ctx.loss_regular_total | Initialize to 0 |
  | ctx.num_samples | Initialize to 0 |
  | ctx.ys_true | Initialize to [] |
  | ctx.ys_prob | Initialize to [] |

- _hook_on_fit_start_calculate_model_size

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.monitor | Track model size |

- _hook_on_epoch_start

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.{ctx.cur_split}_loader | Initialize DataLoader |

- _hook_on_batch_start_init

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.data_batch | Initialize batch data |

- _hook_on_batch_forward

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.y_true | Move to ctx.device |
  | ctx.y_prob | Forward propagation to get y_prob |
  | ctx.loss_batch | Calculate the loss |
  | ctx.batch_size | Get the batch_size |

- _hook_on_batch_forward_regularizer

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.loss_regular | Calculate the regular loss |
  | ctx.loss_task | Sum the ctx.loss_regular and ctx.loss |

- _hook_on_batch_forward_flop_count

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.monitor | Track average flops |

- _hook_on_batch_backward

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.optimizer | Update by gradient |
  | ctx.loss_task | Backward propagation |
  | ctx.scheduler | Update by gradient |

- _hook_on_batch_end

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.num_samples | Add ctx.batch_size |
  | ctx.loss_batch_total | Add batch loss |
  | ctx.loss_regular_total | Add batch regular loss |
  | ctx.ys_true | Append ctx.y_true |
  | ctx.ys_prob | Append ctx.y_prob |

- _hook_on_fit_end

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.ys_true | Convert to numpy.array |
  | ctx.ys_prob | Convert to numpy.array |
  | ctx.monitor | Evaluate the results |
  | ctx.eval_metrics | Get evaluated results from ctx.monitor |
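The bookkeeping done by the fit-start, batch-end, and fit-end hooks can be sketched in plain Python without PyTorch. The three functions below are simplified stand-ins, not the GeneralTorchTrainer hooks themselves. Two assumptions are made for the example: the batch loss is weighted by batch size before being accumulated, and the final metrics (average loss and a thresholded accuracy) are computed inline rather than by ctx.monitor.

```python
from types import SimpleNamespace


def fit_start_init(ctx):
    # Reset the running statistics, as in the fit-start table above.
    ctx.loss_batch_total = 0.0
    ctx.loss_regular_total = 0.0
    ctx.num_samples = 0
    ctx.ys_true = []
    ctx.ys_prob = []


def batch_end(ctx):
    # Accumulate per-batch values into the running statistics.
    ctx.num_samples += ctx.batch_size
    # Assumption: weight the batch loss by its size before summing.
    ctx.loss_batch_total += ctx.loss_batch * ctx.batch_size
    ctx.loss_regular_total += ctx.loss_regular
    ctx.ys_true.append(ctx.y_true)
    ctx.ys_prob.append(ctx.y_prob)


def fit_end(ctx):
    # Flatten the per-batch lists (the real hook converts to numpy arrays).
    ctx.ys_true = [y for batch in ctx.ys_true for y in batch]
    ctx.ys_prob = [p for batch in ctx.ys_prob for p in batch]
    # Stand-in for the monitor: binary accuracy with a 0.5 threshold.
    correct = sum(int(p >= 0.5) == y for y, p in zip(ctx.ys_true, ctx.ys_prob))
    ctx.eval_metrics = {
        'avg_loss': ctx.loss_batch_total / ctx.num_samples,
        'acc': correct / ctx.num_samples,
    }


ctx = SimpleNamespace()
fit_start_init(ctx)
batches = [
    dict(y_true=[1, 0], y_prob=[0.9, 0.2], loss_batch=0.5),
    dict(y_true=[1, 1], y_prob=[0.4, 0.8], loss_batch=0.25),
]
for batch in batches:
    # These assignments mimic what a forward hook would leave in ctx.
    ctx.y_true, ctx.y_prob = batch['y_true'], batch['y_prob']
    ctx.loss_batch, ctx.loss_regular = batch['loss_batch'], 0.0
    ctx.batch_size = len(ctx.y_true)
    batch_end(ctx)
fit_end(ctx)
```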
DittoTrainer
- _hook_on_fit_start_set_regularized_para

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.global_model | Move to ctx.device and set to train mode |
  | ctx.local_model | Move to ctx.device and set to train mode |
  | ctx.optimizer_for_global_model | Initialize by ctx.cfg and wrapped by wrap_regularized_optimizer |
  | ctx.optimizer_for_local_model | Initialize by ctx.cfg and set compared parameter group |

- _hook_on_fit_start_clean

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.optimizer | Delete |
  | ctx.num_samples_local_model_train | Initialize to 0 |

- _hook_on_fit_start_switch_local_model

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.model | Set to ctx.local_model and set to eval mode |

- _hook_on_batch_start_switch_model

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.use_local_model_current | Set to True or False |
  | ctx.model | Set to ctx.local_model or ctx.global_model |
  | ctx.optimizer | Set to ctx.optimizer_for_local_model or ctx.optimizer_for_global_model |

- _hook_on_batch_forward_cnt_num

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.num_samples_local_model_train | Add ctx.batch_size |

- _hook_on_batch_end_flop_count

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.monitor | Monitor total flops |

- _hook_on_fit_end_calibrate

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.num_samples | Minus ctx.num_samples_local_model_train |
  | ctx.eval_metrics | Record train_total and train_total_local_model |

- _hook_on_fit_end_switch_global_model

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.model | Set to ctx.global_model |

- _hook_on_fit_end_free_cuda

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.global_model | Move to cpu |
  | ctx.local_model | Move to cpu |
pFedMeTrainer
- _hook_on_fit_start_set_local_para_tmp

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.optimizer | Wrapped by wrap_regularized_optimizer and set compared parameter group |
  | ctx.pFedMe_outer_lr | Initialize to ctx.cfg.train.optimizer.lr |
  | ctx.pFedMe_local_model_tmp | Copy from ctx.model |

- _hook_on_batch_start_init_pfedme

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.data_batch_cache | Copy from ctx.data_batch |
  | ctx.pFedMe_approx_fit_counter | Count to refresh data every K step |

- _hook_on_batch_end_flop_count

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.monitor | Monitor total flops |

- _hook_on_epoch_end_flop_count

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.monitor | Monitor total flops |

- _hook_on_epoch_end_update_local

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.model | Update parameters by ctx.pFedMe_local_model_tmp |
  | ctx.optimizer | Set compared parameter group |

- _hook_on_fit_end_update_local

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.model | Update parameters by ctx.pFedMe_local_model_tmp |
  | ctx.pFedMe_local_model_tmp | Delete |
FedProxTrainer & NbaflTrainer
- _hook_record_initialization

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.weight_init | Copy from ctx.model |

- _hook_del_initialization

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.weight_init | Set to None |

- _hook_inject_noise_in_upload

  | Modified attribute | Operation |
  | --- | --- |
  | ctx.model | Inject noise to parameters |
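To illustrate why a copy of the initial weights is recorded, the sketch below uses it to compute the standard FedProx proximal term, (mu / 2) * ||w - w_init||^2, on a toy model stored as a list of floats. The function names echo the hooks above but are simplified stand-ins. The plain-dict ctx, the proximal_term helper, and the choice of Gaussian noise for the upload step are assumptions made for this example.

```python
import copy
import random


def record_initialization(ctx):
    # Keep a copy of the parameters received at the start of local training.
    ctx['weight_init'] = copy.deepcopy(ctx['model'])


def del_initialization(ctx):
    # Drop the copy once local training is finished.
    ctx['weight_init'] = None


def proximal_term(ctx, mu):
    # FedProx penalty: (mu / 2) * squared distance to the initial weights.
    return 0.5 * mu * sum(
        (w - w0) ** 2 for w, w0 in zip(ctx['model'], ctx['weight_init']))


def inject_noise_in_upload(ctx, scale, rng):
    # Assumption: zero-mean Gaussian noise added to every parameter.
    ctx['model'] = [w + rng.gauss(0.0, scale) for w in ctx['model']]


ctx = {'model': [1.0, 2.0]}
record_initialization(ctx)
ctx['model'] = [2.0, 4.0]            # pretend local training moved the weights
penalty = proximal_term(ctx, mu=0.5)  # 0.5 * 0.5 * (1 + 4) = 1.25
del_initialization(ctx)

before_upload = list(ctx['model'])
inject_noise_in_upload(ctx, scale=0.1, rng=random.Random(0))
```

Because the recorded copy is a deep copy, later updates to ctx['model'] do not change ctx['weight_init'], so the penalty measures drift from the weights received at the start of the round.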