from federatedscope.core.configs.config import CN
from federatedscope.register import register_config


def extend_training_cfg(cfg):
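    """Populate ``cfg`` with the default trainer, training, fine-tuning,
    gradient, and early-stopping options, and register the corresponding
    sanity-check function (``assert_training_cfg``).
    """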
    # ---------------------------------------------------------------------- #
    # Trainer related options
    # ---------------------------------------------------------------------- #
    cfg.trainer = CN()

    cfg.trainer.type = 'general'

    cfg.trainer.sam = CN()
    cfg.trainer.sam.adaptive = False
    cfg.trainer.sam.rho = 1.0
    cfg.trainer.sam.eta = 0.0

    cfg.trainer.local_entropy = CN()
    cfg.trainer.local_entropy.gamma = 0.03
    cfg.trainer.local_entropy.inc_factor = 1.0
    cfg.trainer.local_entropy.eps = 1e-4
    cfg.trainer.local_entropy.alpha = 0.75

    # atc (TODO: merge later)
    cfg.trainer.disp_freq = 50
    cfg.trainer.val_freq = 100000000  # eval freq across batches
    cfg.trainer.log_dir = ''
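
    # Note (descriptive, inferred from the key names above): `trainer.sam.*`
    # configures sharpness-aware-minimization-style training and
    # `trainer.local_entropy.*` configures local-entropy-based smoothing.
    # A YAML override for this group might look like (illustrative sketch):
    #
    #   trainer:
    #     type: 'general'
    #     sam:
    #       adaptive: True
    #       rho: 0.05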

    # ---------------------------------------------------------------------- #
    # Training related options
    # ---------------------------------------------------------------------- #
    cfg.train = CN()

    cfg.train.local_update_steps = 1
    cfg.train.batch_or_epoch = 'batch'
    cfg.train.data_para_dids = []  # `torch.nn.DataParallel` devices

    cfg.train.optimizer = CN(new_allowed=True)
    cfg.train.optimizer.type = 'SGD'
    cfg.train.optimizer.lr = 0.1

    # trafficflow
    cfg.train.loss_func = 'mae'
    cfg.train.seed = 10
    cfg.train.batch_size = 64
    cfg.train.epochs = 300
    cfg.train.lr_init = 0.003
    cfg.train.weight_decay = 0
    cfg.train.lr_decay = False
    cfg.train.lr_decay_rate = 0.3
    cfg.train.lr_decay_step = [5, 20, 40, 70]
    cfg.train.early_stop = True
    cfg.train.early_stop_patience = 15
    cfg.train.grad_norm = False
    cfg.train.max_grad_norm = 5
    cfg.train.real_value = True

    # You can add a new argument 'aa' via `cfg.train.scheduler.aa = 'bb'`
    cfg.train.scheduler = CN(new_allowed=True)
    cfg.train.scheduler.type = ''
    cfg.train.scheduler.warmup_ratio = 0.0
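
    # Because the scheduler node is created with `new_allowed=True`, extra
    # keys can be attached without changing this file, as the comment above
    # notes. For example (illustrative; 'StepLR' and `step_size` are
    # assumptions, not defaults defined here):
    #   cfg.train.scheduler.type = 'StepLR'
    #   cfg.train.scheduler.step_size = 10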

    # ---------------------------------------------------------------------- #
    # Finetune related options
    # ---------------------------------------------------------------------- #
    cfg.finetune = CN()

    cfg.finetune.before_eval = False
    cfg.finetune.local_update_steps = 1
    cfg.finetune.batch_or_epoch = 'epoch'
    cfg.finetune.freeze_param = ""

    cfg.finetune.optimizer = CN(new_allowed=True)
    cfg.finetune.optimizer.type = 'SGD'
    cfg.finetune.optimizer.lr = 0.1

    cfg.finetune.scheduler = CN(new_allowed=True)
    cfg.finetune.scheduler.type = ''
    cfg.finetune.scheduler.warmup_ratio = 0.0

    # simple-tuning
    cfg.finetune.simple_tuning = False  # use simple tuning, default: False
    cfg.finetune.epoch_linear = 10  # training epoch number, default: 10
    cfg.finetune.lr_linear = 0.005  # learning rate for training linear head
    cfg.finetune.weight_decay = 0.0
    cfg.finetune.local_param = []  # tuning parameters list
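
    # Illustrative sketch (values are assumptions, not defaults defined
    # here): to run a short fine-tuning pass before each evaluation, one
    # might set
    #   cfg.finetune.before_eval = True
    #   cfg.finetune.local_update_steps = 5
    #   cfg.finetune.batch_or_epoch = 'batch'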

    # ---------------------------------------------------------------------- #
    # Gradient related options
    # ---------------------------------------------------------------------- #
    cfg.grad = CN()
    cfg.grad.grad_clip = -1.0  # negative numbers indicate we do not clip grad
    cfg.grad.grad_accum_count = 1
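
    # For example (illustrative): `cfg.grad.grad_clip = 5.0` enables
    # gradient clipping at norm 5.0, and `cfg.grad.grad_accum_count = 4`
    # accumulates gradients over 4 batches before each optimizer step.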

    # ---------------------------------------------------------------------- #
    # Early stopping related options
    # ---------------------------------------------------------------------- #
    cfg.early_stop = CN()

    # patience (int): How long to wait after the last time the monitored
    # metric improved.
    # Note that the actual_checking_round = patience * cfg.eval.freq
    # To disable early stopping, set early_stop.patience to 0
    cfg.early_stop.patience = 5
    # delta (float): Minimum change in the monitored metric that counts as
    # an improvement.
    cfg.early_stop.delta = 0.0
    # Stop early when there has been no improvement over the last `patience`
    # rounds; mode in ['mean', 'best']
    cfg.early_stop.improve_indicator_mode = 'best'
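
    # Worked example (following the note on `patience` above): with
    # `patience = 5` and `cfg.eval.freq = 2`, the actual checking round is
    # 5 * 2 = 10, i.e. training stops once the monitored metric has not
    # improved by at least `delta` across the last 5 evaluations.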

    # TODO: trafficflow

    # --------------- register corresponding check function ---------------- #
    cfg.register_cfg_check_fun(assert_training_cfg)


def assert_training_cfg(cfg):
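    """Sanity-check the training-related options populated above. Registered
    via ``cfg.register_cfg_check_fun`` and raises ``ValueError`` on invalid
    combinations.
    """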
    if cfg.train.batch_or_epoch not in ['batch', 'epoch']:
        raise ValueError(
            "Value of 'cfg.train.batch_or_epoch' must be chosen from "
            "['batch', 'epoch'].")

    if cfg.finetune.batch_or_epoch not in ['batch', 'epoch']:
        raise ValueError(
            "Value of 'cfg.finetune.batch_or_epoch' must be chosen from "
            "['batch', 'epoch'].")

    # TODO: should not be here?
    if cfg.backend not in ['torch', 'tensorflow']:
        raise ValueError(
            "Value of 'cfg.backend' must be chosen from ['torch', "
            "'tensorflow'].")
    if cfg.backend == 'tensorflow' and cfg.federate.mode == 'standalone':
        raise ValueError(
            "We only support running in distributed mode when the backend "
            "is tensorflow.")
    if cfg.backend == 'tensorflow' and cfg.use_gpu is True:
        raise ValueError(
            "We only support running on CPU when the backend is tensorflow.")

    if cfg.finetune.before_eval is False and cfg.finetune.local_update_steps \
            <= 0:
        raise ValueError(
            f"When adopting fine-tuning, please set a valid number of local "
            f"fine-tuning steps, got {cfg.finetune.local_update_steps}")


register_config("fl_training", extend_training_cfg)
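
# Usage sketch (illustrative): importing this module registers
# `extend_training_cfg` under the key "fl_training", so the options above
# become part of the global FederatedScope config and can be overridden from
# a YAML file, e.g.:
#
#   train:
#     local_update_steps: 5
#     batch_or_epoch: 'epoch'
#     optimizer:
#       lr: 0.01
#   early_stop:
#     patience: 10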