143 lines
5.5 KiB
Python
143 lines
5.5 KiB
Python
from federatedscope.core.configs.config import CN
|
||
from federatedscope.core.configs.yacs_config import Argument
|
||
from federatedscope.register import register_config
|
||
|
||
|
||
def extend_fl_algo_cfg(cfg):
    """Add the default options of the built-in FL algorithms to ``cfg``.

    Creates the ``fedopt``, ``fedprox``, ``fedswa``, ``personalization``,
    ``fedsageplus``, ``gcflplus`` and ``flitplus`` sub-nodes with their
    default values, then registers ``assert_fl_algo_cfg`` as a config
    check function.

    Args:
        cfg: the root config node; it is extended in place.
    """
    # ---------------------------------------------------------------------- #
    # FedOpt related options (a general FL algorithm)
    # ---------------------------------------------------------------------- #
    cfg.fedopt = CN()
    fedopt = cfg.fedopt

    fedopt.use = False

    # new_allowed=True lets users supply extra optimizer kwargs
    fedopt.optimizer = CN(new_allowed=True)
    fedopt_optim = fedopt.optimizer
    fedopt_optim.type = Argument(
        'SGD', description="optimizer type for FedOPT")
    fedopt_optim.lr = Argument(
        0.01, description="learning rate for FedOPT optimizer")
    fedopt.annealing = False
    fedopt.annealing_step_size = 2000
    fedopt.annealing_gamma = 0.5

    # ---------------------------------------------------------------------- #
    # FedProx related options (a general FL algorithm)
    # ---------------------------------------------------------------------- #
    cfg.fedprox = CN()
    fedprox = cfg.fedprox

    fedprox.use = False
    fedprox.mu = 0.

    # ---------------------------------------------------------------------- #
    # FedSWA related options: Stochastic Weight Averaging (SWA)
    # ---------------------------------------------------------------------- #
    cfg.fedswa = CN()
    fedswa = cfg.fedswa
    fedswa.use = False
    fedswa.freq = 10
    fedswa.start_rnd = 30

    # ---------------------------------------------------------------------- #
    # Personalization related options (pFL)
    # ---------------------------------------------------------------------- #
    cfg.personalization = CN()
    pers = cfg.personalization

    # Client-distinct parameter names, e.g., ['pre', 'post']
    pers.local_param = []
    pers.share_non_trainable_para = False
    # -1 means "same as cfg.train.local_update_steps" (see the check func)
    pers.local_update_steps = -1
    # @regular_weight:
    #   The smaller regular_weight is, the more emphasis is put on the
    #   personalized model.
    #   For Ditto, the default value is 0.1 and the search space is
    #   [0.05, 0.1, 0.2, 1, 2].
    #   For pFedMe, the default value is 15.
    pers.regular_weight = 0.1

    # @lr:
    #   1) For pFedMe, the personalized learning rate used to approximate
    #      theta with K steps.
    #   2) A value <= 0.0 means "use cfg.train.optimizer.lr", in case users
    #      have not specified a valid lr.
    pers.lr = 0.0

    pers.K = 5  # the local approximation steps for pFedMe
    pers.beta = 1.0  # the average moving parameter for pFedMe

    # Parameters for FedRep:
    pers.lr_feature = 0.1  # learning rate: feature extractors
    pers.lr_linear = 0.1  # learning rate: linear head
    pers.epoch_feature = 1  # training epoch number
    pers.epoch_linear = 2  # training epoch number
    pers.weight_decay = 0.0

    # ---------------------------------------------------------------------- #
    # FedSage+ related options (graph FL)
    # ---------------------------------------------------------------------- #
    cfg.fedsageplus = CN()
    sage = cfg.fedsageplus

    sage.num_pred = 5  # number of nodes generated by the generator
    sage.gen_hidden = 128  # hidden layer dimension of the generator
    sage.hide_portion = 0.5  # portion of the graph to hide
    sage.fedgen_epoch = 200  # federated training rounds for the generator
    sage.loc_epoch = 1  # local pre-training rounds for the generator
    sage.a = 1.0  # coefficient for the missing-node-number criterion
    sage.b = 1.0  # coefficient for the feature criterion
    sage.c = 1.0  # coefficient for the classification criterion

    # ---------------------------------------------------------------------- #
    # GCFL+ related options (graph FL)
    # ---------------------------------------------------------------------- #
    cfg.gcflplus = CN()
    gcfl = cfg.gcflplus

    gcfl.EPS_1 = 0.05  # bound for mean_norm
    gcfl.EPS_2 = 0.1  # bound for max_norm
    gcfl.seq_length = 5  # length of the gradient sequence
    gcfl.standardize = False  # whether to standardize dtw_distances

    # ---------------------------------------------------------------------- #
    # FLIT+ related options (graph FL)
    # ---------------------------------------------------------------------- #
    cfg.flitplus = CN()
    flit = cfg.flitplus

    flit.tmpFed = 0.5  # gamma in focal loss (Eq.4)
    flit.lambdavat = 0.5  # lambda in phi (Eq.10)
    flit.factor_ema = 0.8  # beta in omega (Eq.12)
    flit.weightReg = 1.0  # balance lossLocalLabel and lossLocalVAT

    # --------------- register corresponding check function ----------
    cfg.register_cfg_check_fun(assert_fl_algo_cfg)
|
||
|
||
|
||
def assert_fl_algo_cfg(cfg):
    """Fill in derived defaults and validate the FL-algorithm options.

    ``cfg`` is mutated in place:

    * ``personalization.local_update_steps == -1`` is replaced by
      ``train.local_update_steps``; any other (user-specified) value is kept.
    * ``personalization.lr <= 0.0`` is replaced by ``train.optimizer.lr``.

    Args:
        cfg: the root config node to check.

    Raises:
        AssertionError: if ``fedswa.use`` is enabled and ``fedswa.start_rnd``
            is not smaller than ``federate.total_round_num``.
    """
    if cfg.personalization.local_update_steps == -1:
        # By default, use the same number of steps as the normal mode.
        # Only the -1 sentinel is replaced; assigning unconditionally would
        # silently discard a user-specified value.
        cfg.personalization.local_update_steps = \
            cfg.train.local_update_steps

    if cfg.personalization.lr <= 0.0:
        # By default, use the same lr as the normal mode
        cfg.personalization.lr = cfg.train.optimizer.lr

    if cfg.fedswa.use:
        assert cfg.fedswa.start_rnd < cfg.federate.total_round_num, \
            f'`cfg.fedswa.start_rnd` {cfg.fedswa.start_rnd} must be smaller ' \
            f'than `cfg.federate.total_round_num` ' \
            f'{cfg.federate.total_round_num}.'
|
||
|
||
|
||
# Register ``extend_fl_algo_cfg`` under the key "fl_algo" via
# ``register_config``; presumably the global config builder calls it to add
# these options — confirm in federatedscope.register.
register_config("fl_algo", extend_fl_algo_cfg)
|