import copy

from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
from federatedscope.core.optimizer import wrap_regularized_optimizer
from typing import Type


def wrap_pFedMeTrainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    """
    Build a ``pFedMeTrainer`` in a plug-in manner by registering new
    functions onto a specific ``BaseTrainer``.

    The pFedMe implementation, "Personalized Federated Learning with Moreau
    Envelopes" (NeurIPS 2020), follows Algorithm 1 in the paper and the
    official code: https://github.com/CharlieDinh/pFedMe
    """
    # ---------------- attribute-level plug-in -----------------------
    init_pFedMe_ctx(base_trainer)

    # ---------------- action-level plug-in -----------------------
    base_trainer.register_hook_in_train(
        new_hook=_hook_on_fit_start_set_local_para_tmp,
        trigger="on_fit_start",
        insert_pos=-1)
    base_trainer.register_hook_in_train(
        new_hook=_hook_on_epoch_end_update_local,
        trigger="on_epoch_end",
        insert_pos=-1)
    base_trainer.register_hook_in_train(new_hook=_hook_on_fit_end_update_local,
                                        trigger="on_fit_end",
                                        insert_pos=-1)

    base_trainer.register_hook_in_train(new_hook=_hook_on_batch_end_flop_count,
                                        trigger="on_batch_end",
                                        insert_pos=-1)
    base_trainer.register_hook_in_train(new_hook=_hook_on_epoch_end_flop_count,
                                        trigger="on_epoch_end",
                                        insert_pos=-1)

    # for the "on_batch_start" trigger: replace the original hooks with the
    # pFedMe ones
    # 1) cache the original hooks for "on_batch_start"; they are replayed by
    # `_hook_on_batch_start_init_pfedme` whenever a fresh batch is needed
    base_trainer.ctx.original_hook_on_batch_start_train = \
        base_trainer.hooks_in_train["on_batch_start"]
    base_trainer.ctx.original_hook_on_batch_start_eval = \
        base_trainer.hooks_in_eval["on_batch_start"]
    # 2) replace the original hooks for "on_batch_start"
    base_trainer.replace_hook_in_train(
        new_hook=_hook_on_batch_start_init_pfedme,
        target_trigger="on_batch_start",
        target_hook_name=None)
    base_trainer.replace_hook_in_eval(
        new_hook=_hook_on_batch_start_init_pfedme,
        target_trigger="on_batch_start",
        target_hook_name=None)

    return base_trainer


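# A minimal usage sketch, assuming a ``GeneralTorchTrainer`` has been built
# elsewhere (the constructor arguments below are illustrative only, not the
# exact signature):
#
#     trainer = GeneralTorchTrainer(model, data, device, config)
#     trainer = wrap_pFedMeTrainer(trainer)
#     trainer.train()  # now runs pFedMe's K-step approximation loop
#

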
def init_pFedMe_ctx(base_trainer):
    """
    Initialize the attributes used by pFedMe; the new attributes are
    prefixed with ``pFedMe_`` to avoid namespace pollution.

    Note:
        The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.optimizer_for_global_model``  False
        ==================================  ===========================
    """
    ctx = base_trainer.ctx
    cfg = base_trainer.cfg

    # pFedMe finds an approximate personalized model with K steps on the same
    # data batch, so the complexity of each pFedMe client is K times that of
    # FedAvg
    ctx.pFedMe_K = cfg.personalization.K
    ctx.num_train_epoch *= ctx.pFedMe_K
    ctx.pFedMe_approx_fit_counter = 0

    # local_model_tmp serves as the reference parameters when finding the
    # approximate \theta in the paper; it is copied from the model in every
    # run_routine
    ctx.pFedMe_local_model_tmp = None


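# For reference, the per-client objective that the hooks below optimize is
# pFedMe's Moreau envelope (Eq. (2) in the paper):
#
#     F_i(w) = min_theta { f_i(theta) + (lambda / 2) * ||theta - w||^2 }
#
# where theta is the personalized model found by the K inner steps, w is the
# local copy of the global model kept in ``pFedMe_local_model_tmp``, and
# lambda corresponds to ``cfg.personalization.regular_weight``.

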
def _hook_on_fit_start_set_local_para_tmp(ctx):
    """
    Note:
        The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.optimizer``                   Wrapped by \
        ``wrap_regularized_optimizer`` and set compared parameter group
        ``ctx.pFedMe_outer_lr``             Initialize to \
        ``ctx.cfg.train.optimizer.lr``
        ``ctx.pFedMe_local_model_tmp``      Copy from ``ctx.model``
        ==================================  ===========================
    """
    # the optimizer used in pFedMe is based on Moreau-envelope regularization;
    # besides, it uses two distinct learning rates: ``personalization.lr``
    # for the approximate (personalized) model and ``train.optimizer.lr``
    # for the base model
    ctx.optimizer = wrap_regularized_optimizer(
        ctx.optimizer, ctx.cfg.personalization.regular_weight)
    for g in ctx.optimizer.param_groups:
        g['lr'] = ctx.cfg.personalization.lr
    ctx.pFedMe_outer_lr = ctx.cfg.train.optimizer.lr

    ctx.pFedMe_local_model_tmp = copy.deepcopy(ctx.model)
    # set the compared model data, so that the optimizer will find the
    # approximate model using trainer.cfg.personalization.lr
    compared_global_model_para = [{
        "params": list(ctx.pFedMe_local_model_tmp.parameters())
    }]
    ctx.optimizer.set_compared_para_group(compared_global_model_para)


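# With the regularized optimizer in place, each inner step on the approximate
# model effectively follows Algorithm 1 in the paper:
#
#     theta <- theta - lr_personal * (grad f_i(theta; D) + lambda * (theta - w))
#
# with lr_personal = ``cfg.personalization.lr`` and w the compared parameter
# group set above.

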
def _hook_on_batch_start_init_pfedme(ctx):
    """
    Note:
        The modified attributes and according operations are shown below:
        ==================================  ===================================
        Attribute                           Operation
        ==================================  ===================================
        ``ctx.data_batch_cache``            Copy from ``ctx.data_batch``
        ``ctx.pFedMe_approx_fit_counter``   Count to refresh data every K steps
        ==================================  ===================================
    """
    # refresh the data every K steps
    if ctx.pFedMe_approx_fit_counter == 0:
        if ctx.cur_mode == "train":
            for hook in ctx.original_hook_on_batch_start_train:
                hook(ctx)
        else:
            for hook in ctx.original_hook_on_batch_start_eval:
                hook(ctx)
        ctx.data_batch_cache = copy.deepcopy(ctx.data_batch)
    else:
        # reuse the cached data, since the original hook `_hook_on_batch_end`
        # will clean `data_batch`
        ctx.data_batch = copy.deepcopy(ctx.data_batch_cache)
    ctx.pFedMe_approx_fit_counter = (ctx.pFedMe_approx_fit_counter +
                                     1) % ctx.pFedMe_K


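# Example trace with K = 3: at counter 0 the original hooks load a fresh
# batch, which is cached; at counters 1 and 2 the cached batch is replayed;
# the counter then wraps back to 0 and a new batch is fetched.

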
def _hook_on_batch_end_flop_count(ctx):
    """
    Note:
        The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.monitor``                     Monitor total flops
        ==================================  ===========================
    """
    # besides the normal forward flops, pFedMe's regularization term adds a
    # cost proportional to the number of model parameters
    ctx.monitor.total_flops += ctx.monitor.total_model_size / 2


def _hook_on_epoch_end_flop_count(ctx):
    """
    Note:
        The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.monitor``                     Monitor total flops
        ==================================  ===========================
    """
    # count the extra flops due to the local weight update at epoch end
    ctx.monitor.total_flops += ctx.monitor.total_model_size / 2


def _hook_on_epoch_end_update_local(ctx):
    """
    Note:
        The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.model``                       Update parameters by \
        ``ctx.pFedMe_local_model_tmp``
        ``ctx.optimizer``                   Set compared parameter group
        ==================================  ===========================
    """
    # update the local weight after finding the approximate theta
    for client_param, local_para_tmp in zip(
            ctx.model.parameters(), ctx.pFedMe_local_model_tmp.parameters()):
        local_para_tmp.data = local_para_tmp.data - \
            ctx.optimizer.regular_weight * \
            ctx.pFedMe_outer_lr * (local_para_tmp.data -
                                   client_param.data)
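    # i.e., the outer update from Algorithm 1 in the paper:
    #     w <- w - lambda * eta * (w - theta)
    # with eta = ``ctx.pFedMe_outer_lr`` and lambda the regularization weight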

    # set the compared model data, so that the optimizer will find the
    # approximate model using trainer.cfg.personalization.lr
    compared_global_model_para = [{
        "params": list(ctx.pFedMe_local_model_tmp.parameters())
    }]
    ctx.optimizer.set_compared_para_group(compared_global_model_para)


def _hook_on_fit_end_update_local(ctx):
    """
    Note:
        The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.model``                       Update parameters by \
        ``ctx.pFedMe_local_model_tmp``
        ``ctx.pFedMe_local_model_tmp``      Delete
        ==================================  ===========================
    """
    # at the end of the routine, keep the updated local model as the
    # personalized result and release the temporary copy
    for param, local_para_tmp in zip(ctx.model.parameters(),
                                     ctx.pFedMe_local_model_tmp.parameters()):
        param.data = local_para_tmp.data

    del ctx.pFedMe_local_model_tmp