import copy
import torch
import logging
import math

from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer

from typing import Type

logger = logging.getLogger(__name__)


def wrap_Simple_tuning_Trainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    # ---------------------------------------------------------------------- #
    # Simple_tuning method:
    # https://arxiv.org/abs/2302.01677
    # Only tune the linear classifier and freeze the feature extractor;
    # the key is to reinitialize the linear classifier
    # (an illustrative standalone sketch is appended at the end of this file).
    # ---------------------------------------------------------------------- #
    init_Simple_tuning_ctx(base_trainer)

    base_trainer.register_hook_in_ft(new_hook=hook_on_fit_start_simple_tuning,
                                     trigger="on_fit_start",
                                     insert_pos=-1)

    return base_trainer


def init_Simple_tuning_ctx(base_trainer):
    ctx = base_trainer.ctx
    cfg = base_trainer.cfg

    ctx.epoch_linear = cfg.finetune.epoch_linear
    ctx.num_train_epoch = ctx.epoch_linear
    ctx.epoch_number = 0
    ctx.lr_linear = cfg.finetune.lr_linear
    ctx.weight_decay = cfg.finetune.weight_decay
    ctx.local_param = cfg.finetune.local_param

    ctx.local_update_param = []
    for name, param in ctx.model.named_parameters():
        if name.split(".")[0] in ctx.local_param:
            ctx.local_update_param.append(param)


def hook_on_fit_start_simple_tuning(ctx):
    ctx.num_train_epoch = ctx.epoch_linear
    ctx.epoch_number = 0

    ctx.optimizer_for_linear = torch.optim.SGD(ctx.local_update_param,
                                               lr=ctx.lr_linear,
                                               momentum=0,
                                               weight_decay=ctx.weight_decay)

    # Reinitialize the local (classifier) parameters: each weight is re-drawn
    # from U(-stdv, stdv) with stdv = 1/sqrt(fan_in), and the bias that
    # follows it reuses the same stdv; only these parameters stay trainable,
    # while the rest of the model (the feature extractor) is frozen.
    for name, param in ctx.model.named_parameters():
        if name.split(".")[0] in ctx.local_param:
            if name.split(".")[1] == 'weight':
                stdv = 1. / math.sqrt(param.size(-1))
                param.data.uniform_(-stdv, stdv)
            else:
                param.data.uniform_(-stdv, stdv)
            param.requires_grad = True
        else:
            param.requires_grad = False

    ctx.optimizer = ctx.optimizer_for_linear
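

# ---------------------------------------------------------------------- #
# Illustrative sketch only (not used by the trainer above): a minimal,
# self-contained imitation of what hook_on_fit_start_simple_tuning does,
# assuming a plain torch model whose top-level sub-module 'classifier' is
# the linear head to be tuned. The model and module names here are made up
# for illustration and are not part of FederatedScope.
# ---------------------------------------------------------------------- #
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Sequential()
    model.add_module('encoder', nn.Linear(32, 16))
    model.add_module('classifier', nn.Linear(16, 4))
    local_param = ['classifier']  # plays the role of cfg.finetune.local_param

    # freeze everything except the classifier and re-draw its parameters
    # from U(-1/sqrt(fan_in), 1/sqrt(fan_in))
    for name, param in model.named_parameters():
        if name.split(".")[0] in local_param:
            if name.split(".")[1] == 'weight':
                stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)
            param.requires_grad = True
        else:
            param.requires_grad = False

    trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    print(trainable)  # expected: ['classifier.weight', 'classifier.bias']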