from federatedscope.core.configs.config import CN
from federatedscope.core.configs.yacs_config import Argument
from federatedscope.register import register_config


def extend_fl_algo_cfg(cfg):
    """Add the default options of the FL algorithms to ``cfg``.

    The following sub-nodes are created on ``cfg``: ``fedopt``, ``fedprox``,
    ``fedswa``, ``personalization``, ``fedsageplus``, ``gcflplus`` and
    ``flitplus``. Finally, ``assert_fl_algo_cfg`` is registered on ``cfg``
    as a check function.

    Args:
        cfg: the config node to extend in place.
    """
    # ---------------------------------------------------------------------- #
    # fedopt related options, a general fl algorithm
    # ---------------------------------------------------------------------- #
    cfg.fedopt = CN()

    cfg.fedopt.use = False

    cfg.fedopt.optimizer = CN(new_allowed=True)
    cfg.fedopt.optimizer.type = Argument(
        'SGD', description="optimizer type for FedOPT")
    cfg.fedopt.optimizer.lr = Argument(
        0.01, description="learning rate for FedOPT optimizer")
    cfg.fedopt.annealing = False
    cfg.fedopt.annealing_step_size = 2000
    cfg.fedopt.annealing_gamma = 0.5

    # ---------------------------------------------------------------------- #
    # fedprox related options, a general fl algorithm
    # ---------------------------------------------------------------------- #
    cfg.fedprox = CN()

    cfg.fedprox.use = False
    cfg.fedprox.mu = 0.

    # ---------------------------------------------------------------------- #
    # fedswa related options, Stochastic Weight Averaging (SWA)
    # ---------------------------------------------------------------------- #
    cfg.fedswa = CN()

    cfg.fedswa.use = False
    cfg.fedswa.freq = 10
    cfg.fedswa.start_rnd = 30

    # ---------------------------------------------------------------------- #
    # Personalization related options, pFL
    # ---------------------------------------------------------------------- #
    cfg.personalization = CN()

    # client-distinct param names, e.g., ['pre', 'post']
    cfg.personalization.local_param = []
    cfg.personalization.share_non_trainable_para = False
    # -1 means "not set"; it is resolved in `assert_fl_algo_cfg`
    cfg.personalization.local_update_steps = -1
    # @regular_weight:
    # The smaller the regular_weight is, the stronger the emphasis on the
    # personalized model
    # For Ditto, the default value=0.1, the search space is [0.05, 0.1, 0.2,
    # 1, 2]
    # For pFedMe, the default value=15
    cfg.personalization.regular_weight = 0.1

    # @lr:
    # 1) For pFedMe, the personalized learning rate to calculate theta
    # approximately using K steps
    # 2) 0.0 indicates using the value of optimizer.lr in case the user
    # has not specified a valid lr
    cfg.personalization.lr = 0.0

    cfg.personalization.K = 5  # the local approximation steps for pFedMe
    cfg.personalization.beta = 1.0  # the average moving parameter for pFedMe

    # parameters for FedRep:
    cfg.personalization.lr_feature = 0.1  # learning rate: feature extractors
    cfg.personalization.lr_linear = 0.1  # learning rate: linear head
    cfg.personalization.epoch_feature = 1  # training epoch number
    cfg.personalization.epoch_linear = 2  # training epoch number
    cfg.personalization.weight_decay = 0.0

    # ---------------------------------------------------------------------- #
    # FedSage+ related options, gfl
    # ---------------------------------------------------------------------- #
    cfg.fedsageplus = CN()

    # Number of nodes generated by the generator
    cfg.fedsageplus.num_pred = 5
    # Hidden layer dimension of generator
    cfg.fedsageplus.gen_hidden = 128
    # Hide graph portion
    cfg.fedsageplus.hide_portion = 0.5
    # Federated training round for generator
    cfg.fedsageplus.fedgen_epoch = 200
    # Local pre-train round for generator
    cfg.fedsageplus.loc_epoch = 1
    # Coefficient for criterion number of missing node
    cfg.fedsageplus.a = 1.0
    # Coefficient for criterion feature
    cfg.fedsageplus.b = 1.0
    # Coefficient for criterion classification
    cfg.fedsageplus.c = 1.0

    # ---------------------------------------------------------------------- #
    # GCFL+ related options, gfl
    # ---------------------------------------------------------------------- #
    cfg.gcflplus = CN()

    # Bound for mean_norm
    cfg.gcflplus.EPS_1 = 0.05
    # Bound for max_norm
    cfg.gcflplus.EPS_2 = 0.1
    # Length of the gradient sequence
    cfg.gcflplus.seq_length = 5
    # Whether standardized dtw_distances
    cfg.gcflplus.standardize = False

    # ---------------------------------------------------------------------- #
    # FLIT+ related options, gfl
    # ---------------------------------------------------------------------- #
    cfg.flitplus = CN()

    cfg.flitplus.tmpFed = 0.5  # gamma in focal loss (Eq.4)
    cfg.flitplus.lambdavat = 0.5  # lambda in phi (Eq.10)
    cfg.flitplus.factor_ema = 0.8  # beta in omega (Eq.12)
    cfg.flitplus.weightReg = 1.0  # balance lossLocalLabel and lossLocalVAT

    # --------------- register corresponding check function ----------
    cfg.register_cfg_check_fun(assert_fl_algo_cfg)


def assert_fl_algo_cfg(cfg):
    """Resolve placeholder values and validate the FL algorithm options.

    * ``cfg.personalization.local_update_steps == -1`` is replaced by
      ``cfg.train.local_update_steps``.
    * ``cfg.personalization.lr <= 0.0`` is replaced by
      ``cfg.train.optimizer.lr``.
    * When ``cfg.fedswa.use`` is set, ``cfg.fedswa.start_rnd`` must be
      smaller than ``cfg.federate.total_round_num``.

    Args:
        cfg: the config node to check; it is modified in place.

    Raises:
        AssertionError: if FedSWA is enabled and its start round is not
            smaller than the total number of rounds.
    """
    if cfg.personalization.local_update_steps == -1:
        # By default, use the same step to normal mode
        cfg.personalization.local_update_steps = \
            cfg.train.local_update_steps

    if cfg.personalization.lr <= 0.0:
        # By default, use the same lr to normal mode
        cfg.personalization.lr = cfg.train.optimizer.lr

    if cfg.fedswa.use:
        assert cfg.fedswa.start_rnd < cfg.federate.total_round_num, \
            f'`cfg.fedswa.start_rnd` {cfg.fedswa.start_rnd} must be smaller ' \
            f'than `cfg.federate.total_round_num` ' \
            f'{cfg.federate.total_round_num}.'


register_config("fl_algo", extend_fl_algo_cfg)