FS-TFP/federatedscope/core/trainers/trainer_FedRep.py
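
"""Plug-in FedRep trainer wrapper for FederatedScope.

FedRep (Collins et al., 2021, https://arxiv.org/abs/2102.07078) splits the
model into a personalized head (the linear classifier, whose top-level
module names are listed in ``cfg.personalization.local_param``) and a shared
feature extractor (every remaining parameter). During each local training
round the head is trained first for ``epoch_linear`` epochs with the feature
extractor frozen, after which the feature extractor is trained for
``epoch_feature`` epochs with the head frozen.
"""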

import copy
import torch
import logging
from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
from typing import Type

logger = logging.getLogger(__name__)


def wrap_FedRepTrainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    """Register the FedRep training hooks on ``base_trainer`` and return it."""
    # ---------------------------------------------------------------------- #
    # FedRep method:
    # https://arxiv.org/abs/2102.07078
    # First train the linear classifier (head), then the feature extractor
    # (body). Linear classifier: local_param; feature extractor: global_param.
    # ---------------------------------------------------------------------- #
    init_FedRep_ctx(base_trainer)

    base_trainer.register_hook_in_train(new_hook=hook_on_fit_start_fedrep,
                                        trigger="on_fit_start",
                                        insert_pos=-1)
    base_trainer.register_hook_in_train(new_hook=hook_on_epoch_start_fedrep,
                                        trigger="on_epoch_start",
                                        insert_pos=-1)

    return base_trainer
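

# Usage sketch (illustrative only; how `base_trainer` is constructed depends
# on the surrounding FederatedScope setup, e.g. the trainer builder, and the
# keyword arguments below are an assumption, not part of this module):
#
#     base_trainer = GeneralTorchTrainer(model=model, data=data,
#                                        device=device, config=cfg,
#                                        monitor=monitor)
#     fedrep_trainer = wrap_FedRepTrainer(base_trainer)
#
# Wrapping mutates and returns the same trainer object: it caches the FedRep
# hyperparameters on `ctx` and registers the two training hooks defined below.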


def init_FedRep_ctx(base_trainer):
    """Split the model parameters into a personalized head (``local_param``)
    and a shared feature extractor, and cache the FedRep hyperparameters on
    the trainer context."""
    ctx = base_trainer.ctx
    cfg = base_trainer.cfg

    # Local training runs `epoch_linear` head epochs followed by
    # `epoch_feature` feature-extractor epochs.
    ctx.epoch_feature = cfg.personalization.epoch_feature
    ctx.epoch_linear = cfg.personalization.epoch_linear
    ctx.num_train_epoch = ctx.epoch_feature + ctx.epoch_linear
    ctx.epoch_number = 0

    ctx.lr_feature = cfg.personalization.lr_feature
    ctx.lr_linear = cfg.personalization.lr_linear
    ctx.weight_decay = cfg.personalization.weight_decay
    ctx.local_param = cfg.personalization.local_param

    # Parameters whose top-level module name appears in `local_param` form
    # the personalized head; everything else belongs to the shared extractor.
    ctx.local_update_param = []
    ctx.global_update_param = []
    for name, param in ctx.model.named_parameters():
        if name.split(".")[0] in ctx.local_param:
            ctx.local_update_param.append(param)
        else:
            ctx.global_update_param.append(param)
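

# Example of the split performed by init_FedRep_ctx (hypothetical parameter
# names, not taken from a real model): with
# cfg.personalization.local_param = ['fc'] and named parameters
# ['conv1.weight', 'conv1.bias', 'fc.weight', 'fc.bias'], the split is
#     local_update_param  -> [fc.weight, fc.bias]         (personalized head)
#     global_update_param -> [conv1.weight, conv1.bias]   (shared extractor)
# Matching uses only the first dot-separated component of the name, so a
# nested name such as 'encoder.fc.weight' would not match 'fc'.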


def hook_on_fit_start_fedrep(ctx):
    """At the start of each local round, build one optimizer per phase and
    freeze the feature extractor so that training begins with the head."""
    ctx.num_train_epoch = ctx.epoch_feature + ctx.epoch_linear
    ctx.epoch_number = 0

    ctx.optimizer_for_feature = torch.optim.SGD(ctx.global_update_param,
                                                lr=ctx.lr_feature,
                                                momentum=0,
                                                weight_decay=ctx.weight_decay)
    ctx.optimizer_for_linear = torch.optim.SGD(ctx.local_update_param,
                                               lr=ctx.lr_linear,
                                               momentum=0,
                                               weight_decay=ctx.weight_decay)

    # Phase 1: only the personalized head is trainable.
    for name, param in ctx.model.named_parameters():
        if name.split(".")[0] in ctx.local_param:
            param.requires_grad = True
        else:
            param.requires_grad = False

    ctx.optimizer = ctx.optimizer_for_linear


def hook_on_epoch_start_fedrep(ctx):
    """Switch from the head phase to the feature-extractor phase once
    `epoch_linear` head epochs have been completed."""
    ctx.epoch_number += 1

    if ctx.epoch_number == ctx.epoch_linear + 1:
        # Phase 2: freeze the head and train the shared feature extractor.
        for name, param in ctx.model.named_parameters():
            if name.split(".")[0] in ctx.local_param:
                param.requires_grad = False
            else:
                param.requires_grad = True
        ctx.optimizer = ctx.optimizer_for_feature
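

# Worked schedule with illustrative values epoch_linear = 2 and
# epoch_feature = 1 (hence num_train_epoch = 3):
#     epoch 1: epoch_number = 1 -> head trained, extractor frozen
#              (ctx.optimizer is optimizer_for_linear, set at fit start)
#     epoch 2: epoch_number = 2 -> head trained, extractor frozen
#     epoch 3: epoch_number = 3 == epoch_linear + 1 -> requires_grad flags are
#              flipped and ctx.optimizer becomes optimizer_for_feature, so the
#              extractor is trained while the head stays frozen.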