FS-TFP/federatedscope/core/configs/cfg_training.py

from federatedscope.core.configs.config import CN
from federatedscope.register import register_config


def extend_training_cfg(cfg):
    # ---------------------------------------------------------------------- #
    # Trainer related options
    # ---------------------------------------------------------------------- #
    cfg.trainer = CN()
    cfg.trainer.type = 'general'
    # SAM (sharpness-aware minimization) trainer hyperparameters
    cfg.trainer.sam = CN()
    cfg.trainer.sam.adaptive = False
    cfg.trainer.sam.rho = 1.0
    cfg.trainer.sam.eta = .0
    # Local-entropy trainer hyperparameters
    cfg.trainer.local_entropy = CN()
    cfg.trainer.local_entropy.gamma = 0.03
    cfg.trainer.local_entropy.inc_factor = 1.0
    cfg.trainer.local_entropy.eps = 1e-4
    cfg.trainer.local_entropy.alpha = 0.75
    # atc (TODO: merge later)
    cfg.trainer.disp_freq = 50
    cfg.trainer.val_freq = 100000000  # eval frequency across batches
    cfg.trainer.log_dir = ''
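
    # Hedged example (not in the upstream file): once the full cfg is built,
    # these trainer options are usually overridden via yacs-style dotted
    # pairs or a YAML file, e.g. (values illustrative only):
    #     cfg.merge_from_list(['trainer.type', 'general',
    #                          'trainer.sam.rho', '0.05'])
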
    # ---------------------------------------------------------------------- #
    # Training related options
    # ---------------------------------------------------------------------- #
    cfg.train = CN()
    cfg.train.local_update_steps = 1
    cfg.train.batch_or_epoch = 'batch'
    cfg.train.data_para_dids = []  # `torch.nn.DataParallel` devices
    cfg.train.optimizer = CN(new_allowed=True)
    cfg.train.optimizer.type = 'SGD'
    cfg.train.optimizer.lr = 0.1
    # trafficflow (traffic-flow prediction) specific training options
    cfg.train.loss_func = 'mae'
    cfg.train.seed = 10
    cfg.train.batch_size = 64
    cfg.train.epochs = 300
    cfg.train.lr_init = 0.003
    cfg.train.weight_decay = 0
    cfg.train.lr_decay = False
    cfg.train.lr_decay_rate = 0.3
    cfg.train.lr_decay_step = [5, 20, 40, 70]
    cfg.train.early_stop = True
    cfg.train.early_stop_patience = 15
    cfg.train.grad_norm = False
    cfg.train.max_grad_norm = 5
    cfg.train.real_value = True
    # You can add new arguments, e.g. 'aa', via
    # `cfg.train.scheduler.aa = 'bb'`; see the example after this block.
    cfg.train.scheduler = CN(new_allowed=True)
    cfg.train.scheduler.type = ''
    cfg.train.scheduler.warmup_ratio = 0.0
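
    # Example of the `new_allowed=True` note above (a sketch; `step_size` is
    # a hypothetical extra key, not one defined in this file):
    #     cfg.train.scheduler.type = 'StepLR'
    #     cfg.train.scheduler.step_size = 10
    # Under yacs semantics, merging an unknown key into a node *without*
    # new_allowed would raise an error instead.
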
    # ---------------------------------------------------------------------- #
    # Finetune related options
    # ---------------------------------------------------------------------- #
    cfg.finetune = CN()
    cfg.finetune.before_eval = False
    cfg.finetune.local_update_steps = 1
    cfg.finetune.batch_or_epoch = 'epoch'
    cfg.finetune.freeze_param = ""
    cfg.finetune.optimizer = CN(new_allowed=True)
    cfg.finetune.optimizer.type = 'SGD'
    cfg.finetune.optimizer.lr = 0.1
    cfg.finetune.scheduler = CN(new_allowed=True)
    cfg.finetune.scheduler.type = ''
    cfg.finetune.scheduler.warmup_ratio = 0.0
    # simple-tuning
    cfg.finetune.simple_tuning = False  # use simple tuning, default: False
    cfg.finetune.epoch_linear = 10  # training epoch number, default: 10
    cfg.finetune.lr_linear = 0.005  # learning rate for training linear head
    cfg.finetune.weight_decay = 0.0
    cfg.finetune.local_param = []  # tuning parameters list
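
    # Hedged illustration of the simple-tuning switches above: as the inline
    # comments suggest, enabling it retrains only the linear head locally
    # for `epoch_linear` epochs at `lr_linear`, e.g. (values illustrative
    # only):
    #     cfg.finetune.before_eval = True
    #     cfg.finetune.simple_tuning = True
    #     cfg.finetune.lr_linear = 0.005
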
    # ---------------------------------------------------------------------- #
    # Gradient related options
    # ---------------------------------------------------------------------- #
    cfg.grad = CN()
    cfg.grad.grad_clip = -1.0  # negative numbers indicate we do not clip grad
    cfg.grad.grad_accum_count = 1
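
    # Note (added, hedged): with gradient accumulation, optimizer steps are
    # typically taken every `grad_accum_count` mini-batches, so the effective
    # batch size is roughly cfg.train.batch_size * cfg.grad.grad_accum_count,
    # e.g. 64 * 4 = 256 when grad_accum_count is set to 4; a positive
    # grad_clip value is commonly used as the max-norm clipping threshold.
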
    # ---------------------------------------------------------------------- #
    # Early stopping related options
    # ---------------------------------------------------------------------- #
    cfg.early_stop = CN()
    # patience (int): How long to wait after the last time the monitored
    # metric improved.
    # Note that the actual_checking_round = patience * cfg.eval.freq
    # To disable early stopping, set early_stop.patience to 0.
    cfg.early_stop.patience = 5
    # delta (float): Minimum change in the monitored metric to qualify as an
    # improvement.
    cfg.early_stop.delta = 0.0
    # Early stop when there is no improvement within the last `patience`
    # rounds; chosen from ['mean', 'best'].
    cfg.early_stop.improve_indicator_mode = 'best'
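
    # Worked example of the note above: with cfg.early_stop.patience = 5 and
    # cfg.eval.freq = 2, improvement is awaited for up to 5 * 2 = 10
    # federated rounds before stopping; patience = 0 disables early stopping.
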
    # TODO: trafficflow
    # --------------- register corresponding check function ----------
    cfg.register_cfg_check_fun(assert_training_cfg)


def assert_training_cfg(cfg):
    if cfg.train.batch_or_epoch not in ['batch', 'epoch']:
        raise ValueError(
            "Value of 'cfg.train.batch_or_epoch' must be chosen from ["
            "'batch', 'epoch'].")
    if cfg.finetune.batch_or_epoch not in ['batch', 'epoch']:
        raise ValueError(
            "Value of 'cfg.finetune.batch_or_epoch' must be chosen from ["
            "'batch', 'epoch'].")

    # TODO: should not be here?
    if cfg.backend not in ['torch', 'tensorflow']:
        raise ValueError(
            "Value of 'cfg.backend' must be chosen from ['torch', "
            "'tensorflow'].")
    if cfg.backend == 'tensorflow' and cfg.federate.mode == 'standalone':
        raise ValueError(
            "We only support running in distributed mode when the backend "
            "is tensorflow")
    if cfg.backend == 'tensorflow' and cfg.use_gpu is True:
        raise ValueError(
            "We only support running on CPU when the backend is tensorflow")

    if cfg.finetune.before_eval is False and cfg.finetune.local_update_steps\
            <= 0:
        raise ValueError(
            f"When adopting fine-tuning, please set a valid number of local "
            f"fine-tune steps, got {cfg.finetune.local_update_steps}")


register_config("fl_training", extend_training_cfg)
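

# ---------------------------------------------------------------------- #
# Hedged usage sketch (not part of the upstream module): it shows how the
# extension above is typically exercised -- build a fresh CN, apply
# `extend_training_cfg`, then read or override options. The registered
# check function only runs when the surrounding framework asserts/freezes
# the config, so this standalone snippet does not trigger it.
# ---------------------------------------------------------------------- #
if __name__ == '__main__':
    demo_cfg = CN()
    extend_training_cfg(demo_cfg)
    # Defaults defined above are now available.
    print(demo_cfg.train.optimizer.type, demo_cfg.train.optimizer.lr)
    # cfg.train.scheduler was created with new_allowed=True, so extra
    # scheduler arguments also survive a YAML/list merge; direct attribute
    # assignment (as here) works on any unfrozen node.
    demo_cfg.train.scheduler.type = 'StepLR'
    demo_cfg.train.scheduler.step_size = 10  # illustrative extra argument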