import copy
import logging

import torch

from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
from federatedscope.core.optimizer import wrap_regularized_optimizer
from federatedscope.core.trainers.utils import calculate_batch_epoch_num
from typing import Type

logger = logging.getLogger(__name__)

DEBUG_DITTO = False

def wrap_DittoTrainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    """
    Build a `DittoTrainer` in a plug-in manner by registering new
    functions into the given `BaseTrainer`.

    The Ditto implementation follows "Ditto: Fair and Robust Federated
    Learning Through Personalization" (ICML 2021), based on Algorithm 2
    in the paper and the official code:
    https://github.com/litian96/ditto
    """
    # ---------------- attribute-level plug-in -----------------------
    init_Ditto_ctx(base_trainer)

    # ---------------- action-level plug-in -----------------------
    base_trainer.register_hook_in_train(new_hook=_hook_on_fit_start_clean,
                                        trigger='on_fit_start',
                                        insert_pos=-1)
    base_trainer.register_hook_in_train(
        new_hook=_hook_on_fit_start_set_regularized_para,
        trigger="on_fit_start",
        insert_pos=0)
    base_trainer.register_hook_in_train(
        new_hook=_hook_on_batch_start_switch_model,
        trigger="on_batch_start",
        insert_pos=0)
    base_trainer.register_hook_in_train(
        new_hook=_hook_on_batch_forward_cnt_num,
        trigger="on_batch_forward",
        insert_pos=-1)
    base_trainer.register_hook_in_train(new_hook=_hook_on_batch_end_flop_count,
                                        trigger="on_batch_end",
                                        insert_pos=-1)
    base_trainer.register_hook_in_train(new_hook=_hook_on_fit_end_calibrate,
                                        trigger='on_fit_end',
                                        insert_pos=-1)
    # evaluation is based on the local personalized model
    base_trainer.register_hook_in_eval(
        new_hook=_hook_on_fit_start_switch_local_model,
        trigger="on_fit_start",
        insert_pos=0)
    base_trainer.register_hook_in_eval(
        new_hook=_hook_on_fit_end_switch_global_model,
        trigger="on_fit_end",
        insert_pos=-1)

    base_trainer.register_hook_in_train(new_hook=_hook_on_fit_end_free_cuda,
                                        trigger="on_fit_end",
                                        insert_pos=-1)
    base_trainer.register_hook_in_eval(new_hook=_hook_on_fit_end_free_cuda,
                                       trigger="on_fit_end",
                                       insert_pos=-1)

    return base_trainer
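

# A hedged usage sketch (not executed here): in FederatedScope the Ditto
# wrapper is normally selected through the personalization config, but it can
# also be applied directly to a built trainer. The constructor arguments
# below (`model`, `data`, `device`, `cfg`, `monitor`) are placeholders for
# objects assembled elsewhere and are an assumption, not part of this file:
#
#     trainer = GeneralTorchTrainer(model=model, data=data, device=device,
#                                   config=cfg, monitor=monitor)
#     ditto_trainer = wrap_DittoTrainer(trainer)
#     ditto_trainer.train()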
def init_Ditto_ctx(base_trainer):
    """
    Init necessary attributes used in Ditto:
    `global_model` acts as the shared global model in FedAvg;
    `local_model` acts as the personalized model, which will be optimized
    with regularization based on the weights of `global_model`.
    """
    ctx = base_trainer.ctx
    cfg = base_trainer.cfg

    ctx.global_model = copy.deepcopy(ctx.model)
    ctx.local_model = copy.deepcopy(ctx.model)  # the personalized model
    ctx.models = [ctx.local_model, ctx.global_model]

    ctx.model = ctx.global_model
    ctx.use_local_model_current = False

    ctx.num_samples_local_model_train = 0

    # track the batch_num and epoch_num for the local & global model
    # respectively
    cfg_p_local_update_steps = cfg.personalization.local_update_steps
    ctx.num_train_batch_for_local_model, \
        ctx.num_train_batch_last_epoch_for_local_model, \
        ctx.num_train_epoch_for_local_model, \
        ctx.num_total_train_batch = \
        calculate_batch_epoch_num(cfg_p_local_update_steps,
                                  cfg.train.batch_or_epoch,
                                  ctx.num_train_data,
                                  cfg.dataloader.batch_size,
                                  cfg.dataloader.drop_last)

    # In the first
    # 1. `num_train_batch` and `num_train_batch_last_epoch` batches
    #    (batch_or_epoch == 'batch' case), or
    # 2. `num_train_epoch` epochs
    #    (batch_or_epoch == 'epoch' case),
    # we update the local (personalized) model, and update the global model
    # in the remaining steps
    if cfg.train.batch_or_epoch == 'batch':
        ctx.num_train_batch += ctx.num_train_batch_for_local_model
        ctx.num_train_batch_last_epoch += \
            ctx.num_train_batch_last_epoch_for_local_model
    else:
        ctx.num_train_epoch += ctx.num_train_epoch_for_local_model
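

# A hedged worked example of the scheduling above (the exact split produced
# by `calculate_batch_epoch_num` is an assumption here, for illustration
# only): with cfg.train.batch_or_epoch == 'epoch',
# cfg.train.local_update_steps == 3 and
# cfg.personalization.local_update_steps == 1, one would expect
#
#     ctx.num_train_epoch_for_local_model == 1
#     ctx.num_train_epoch == 3 + 1 == 4
#
# i.e., the first epoch of each round updates the personalized model and the
# remaining three epochs update the global model.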
def _hook_on_fit_start_set_regularized_para(ctx):
    """
    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.global_model``                Move to ``ctx.device`` and set \
        to ``train`` mode
        ``ctx.local_model``                 Move to ``ctx.device`` and set \
        to ``train`` mode
        ``ctx.optimizer_for_global_model``  Initialize by ``ctx.cfg``
        ``ctx.optimizer_for_local_model``   Initialize by ``ctx.cfg``, \
        wrapped by ``wrap_regularized_optimizer`` and set compared \
        parameter group
        ==================================  ===========================
    """
    # set the compared model data for the local personalized model
    ctx.global_model.to(ctx.device)
    ctx.local_model.to(ctx.device)
    ctx.global_model.train()
    ctx.local_model.train()
    compared_global_model_para = [{
        "params": list(ctx.global_model.parameters())
    }]

    ctx.optimizer_for_global_model = get_optimizer(ctx.global_model,
                                                   **ctx.cfg.train.optimizer)
    ctx.optimizer_for_local_model = get_optimizer(ctx.local_model,
                                                  **ctx.cfg.train.optimizer)

    ctx.optimizer_for_local_model = wrap_regularized_optimizer(
        ctx.optimizer_for_local_model, ctx.cfg.personalization.regular_weight)

    ctx.optimizer_for_local_model.set_compared_para_group(
        compared_global_model_para)
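

# A minimal sketch (illustrative only, never called by the trainer) of the
# personalized update that the regularized optimizer above is expected to
# perform, following Ditto's Algorithm 2: a gradient step on the local loss
# plus a proximal pull of the personalized weights toward the global weights,
# scaled by `regular_weight` (lambda in the paper). The function name and the
# explicit parameter loop are assumptions for illustration, not
# FederatedScope API.
def _ditto_personalized_step_sketch(local_model, global_model, lr,
                                    regular_weight):
    with torch.no_grad():
        for p_local, p_global in zip(local_model.parameters(),
                                     global_model.parameters()):
            if p_local.grad is None:
                continue
            # v_k <- v_k - lr * (grad F_k(v_k) + lambda * (v_k - w_global))
            p_local.add_(-lr * (p_local.grad +
                                regular_weight * (p_local - p_global)))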
def _hook_on_fit_start_clean(ctx):
    """
    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.optimizer``                   Delete
        ``ctx.num_..._local_model_train``   Initialize to 0
        ==================================  ===========================
    """
    # remove the unnecessary optimizer
    del ctx.optimizer
    ctx.num_samples_local_model_train = 0


def _hook_on_fit_end_calibrate(ctx):
    """
    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.num_samples``                 Minus \
        ``ctx.num_samples_local_model_train``
        ``ctx.eval_metrics``                Record ``train_total`` and \
        ``train_total_local_model``
        ==================================  ===========================
    """
    # make num_samples only count the samples used to train the global model
    # (num_samples will be used in the aggregation process)
    ctx.num_samples -= ctx.num_samples_local_model_train
    ctx.eval_metrics['train_total'] = ctx.num_samples
    ctx.eval_metrics['train_total_local_model'] = \
        ctx.num_samples_local_model_train
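

# A hedged numeric illustration of the calibration above (numbers are
# arbitrary): if a client processed 150 samples during local training and 50
# of them were used to update the personalized model, then after this hook
# ctx.num_samples == 100, so only the samples that trained the shared global
# model are reported as `train_total` and used to weight this client's
# contribution during aggregation.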


def _hook_on_batch_end_flop_count(ctx):
    """
    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.monitor``                     Monitor total flops
        ==================================  ===========================
    """
    # besides the normal forward flops, the regularization adds a cost
    # proportional to the number of model parameters
    ctx.monitor.total_flops += ctx.monitor.total_model_size / 2


def _hook_on_batch_forward_cnt_num(ctx):
    """
    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.num_..._local_model_train``   Add ``ctx.batch_size``
        ==================================  ===========================
    """
    if ctx.use_local_model_current:
        ctx.num_samples_local_model_train += ctx.batch_size


def _hook_on_batch_start_switch_model(ctx):
    """
    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.use_local_model_current``     Set to ``True`` or ``False``
        ``ctx.model``                       Set to ``ctx.local_model`` or \
        ``ctx.global_model``
        ``ctx.optimizer``                   Set to \
        ``ctx.optimizer_for_local_model`` or ``ctx.optimizer_for_global_model``
        ==================================  ===========================
    """
    if ctx.cfg.train.batch_or_epoch == 'batch':
        if ctx.cur_epoch_i == (ctx.num_train_epoch - 1):
            ctx.use_local_model_current = \
                ctx.cur_batch_i < \
                ctx.num_train_batch_last_epoch_for_local_model
        else:
            ctx.use_local_model_current = \
                ctx.cur_batch_i < ctx.num_train_batch_for_local_model
    else:
        ctx.use_local_model_current = \
            ctx.cur_epoch_i < ctx.num_train_epoch_for_local_model

    if DEBUG_DITTO:
        logger.info("====================================================")
        logger.info(f"cur_epoch_i: {ctx.cur_epoch_i}")
        logger.info(f"num_train_epoch: {ctx.num_train_epoch}")
        logger.info(f"cur_batch_i: {ctx.cur_batch_i}")
        logger.info(f"num_train_batch: {ctx.num_train_batch}")
        logger.info(f"num_train_batch_for_local_model: "
                    f"{ctx.num_train_batch_for_local_model}")
        logger.info(f"num_train_epoch_for_local_model: "
                    f"{ctx.num_train_epoch_for_local_model}")
        logger.info(f"use_local_model: {ctx.use_local_model_current}")

    if ctx.use_local_model_current:
        ctx.model = ctx.local_model
        ctx.optimizer = ctx.optimizer_for_local_model
    else:
        ctx.model = ctx.global_model
        ctx.optimizer = ctx.optimizer_for_global_model
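

# A hedged worked example of the switch above (numbers are arbitrary): with
# cfg.train.batch_or_epoch == 'batch' and
# ctx.num_train_batch_for_local_model == 2, batches 0 and 1 of each epoch
# update the personalized `ctx.local_model` with the regularized optimizer,
# while the remaining batches of the epoch update `ctx.global_model`; in the
# final epoch the cut-off is `num_train_batch_last_epoch_for_local_model`
# instead.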


# Note that Ditto only exchanges and updates the parameters of `global_model`
# with the other FL participants, and ctx.model already points to
# ctx.global_model during the remaining steps, so we do not need to register
# the following hook:
# def hook_on_fit_end_link_global_model(ctx):
#     ctx.model = ctx.global_model
def _hook_on_fit_start_switch_local_model(ctx):
    """
    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.model``                       Set to ``ctx.local_model`` and \
        set to ``eval`` mode
        ==================================  ===========================
    """
    ctx.model = ctx.local_model
    ctx.model.eval()


def _hook_on_fit_end_switch_global_model(ctx):
    """
    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.model``                       Set to ``ctx.global_model``
        ==================================  ===========================
    """
    ctx.model = ctx.global_model


def _hook_on_fit_end_free_cuda(ctx):
    """
    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.global_model``                Move to ``cpu``
        ``ctx.local_model``                 Move to ``cpu``
        ==================================  ===========================
    """
    ctx.global_model.to(torch.device("cpu"))
    ctx.local_model.to(torch.device("cpu"))