import copy
import logging
import federatedscope.register as register

logger = logging.getLogger(__name__)

try:
    import torch
except ImportError:
    torch = None

try:
    from federatedscope.contrib.scheduler import *
except ImportError as error:
    logger.warning(
        f'{error} in `federatedscope.contrib.scheduler`, some modules are not '
        f'available.')

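# How user-defined schedulers plug in (a sketch, not executed here): following
# ``federatedscope/contrib/scheduler/example.py``, a builder function that
# returns a scheduler for the ``type`` strings it recognizes (and None
# otherwise) is added to ``register.scheduler_dict``. The helper name
# ``register_scheduler`` and the builder signature below are assumptions taken
# from that contrib example, not guarantees of this module:
#
#     from federatedscope.register import register_scheduler
#
#     def call_my_scheduler(optimizer, type, **kwargs):
#         if type == 'my_scheduler':
#             from torch.optim.lr_scheduler import StepLR
#             return StepLR(optimizer, step_size=10)
#
#     register_scheduler('my_scheduler', call_my_scheduler)
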
def get_scheduler(optimizer, type, **kwargs):
    """
    This function builds an instance of learning-rate scheduler.

    Args:
        optimizer: the optimizer whose learning rate is to be scheduled
        type: name of the scheduler, i.e. a key registered in
            ``register.scheduler_dict``, one of the built-in types
            ``warmup_step`` / ``warmup_noam``, or a class name from
            ``torch.optim.lr_scheduler``
        **kwargs: keyword arguments forwarded to the scheduler constructor

    Returns:
        An instantiated scheduler, or ``None`` if ``type`` is an empty \
        string or torch is not installed.

    Note:
        Please follow ``contrib.scheduler.example`` to implement your own \
        scheduler.
    """
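    # Resolution order: (1) a scheduler produced by any builder registered in
    # ``register.scheduler_dict`` (e.g. via ``federatedscope.contrib.scheduler``),
    # (2) the built-in ``warmup_step`` / ``warmup_noam`` schedules below,
    # (3) a class looked up by name in ``torch.optim.lr_scheduler``.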
    # In case the user has not called cfg.freeze(), the passed kwargs may
    # still carry CfgNode bookkeeping entries; drop them since they are not
    # valid scheduler arguments.
    tmp_kwargs = copy.deepcopy(kwargs)
    if '__help_info__' in tmp_kwargs:
        del tmp_kwargs['__help_info__']
    if '__cfg_check_funcs__' in tmp_kwargs:
        del tmp_kwargs['__cfg_check_funcs__']
    if 'is_ready_for_run' in tmp_kwargs:
        del tmp_kwargs['is_ready_for_run']
    if 'warmup_ratio' in tmp_kwargs:
        del tmp_kwargs['warmup_ratio']
    # Keep the warmup arguments for the built-in warmup schedulers below, but
    # remove them from the kwargs forwarded to torch schedulers.
    if 'warmup_steps' in tmp_kwargs:
        warmup_steps = tmp_kwargs['warmup_steps']
        del tmp_kwargs['warmup_steps']
    if 'total_steps' in tmp_kwargs:
        total_steps = tmp_kwargs['total_steps']
        del tmp_kwargs['total_steps']

    # A scheduler built by a registered (user-defined) builder takes
    # precedence; each builder returns None for types it does not handle.
    for func in register.scheduler_dict.values():
        scheduler = func(optimizer, type, **tmp_kwargs)
        if scheduler is not None:
            return scheduler

    if type == 'warmup_step':
        from torch.optim.lr_scheduler import LambdaLR

        # Linear warmup from 0 to the base lr over ``warmup_steps``, followed
        # by a linear decay to 0 at ``total_steps``; both values must be
        # provided via kwargs.
        def lr_lambda(cur_step):
            if cur_step < warmup_steps:
                return float(cur_step) / float(max(1, warmup_steps))
            return max(
                0.0,
                float(total_steps - cur_step) /
                float(max(1, total_steps - warmup_steps)))

        return LambdaLR(optimizer, lr_lambda)
    elif type == 'warmup_noam':
        from torch.optim.lr_scheduler import LambdaLR

        # Noam-style (inverse square-root) schedule: the lr factor grows
        # linearly for ``warmup_steps`` steps and then decays proportionally
        # to cur_step ** -0.5.
        def lr_lambda(cur_step):
            return min(
                max(1, cur_step)**(-0.5),
                max(1, cur_step) * warmup_steps**(-1.5))

        return LambdaLR(optimizer, lr_lambda)

    # Fall back to the schedulers shipped with torch, addressed by class name.
    if torch is None or type == '':
        return None
    if isinstance(type, str):
        if hasattr(torch.optim.lr_scheduler, type):
            return getattr(torch.optim.lr_scheduler, type)(optimizer,
                                                           **tmp_kwargs)
        else:
            raise NotImplementedError(
                'Scheduler {} is not implemented'.format(type))
    else:
        raise TypeError('The scheduler `type` must be a string.')
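

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library API): it
# exercises the built-in branches of ``get_scheduler``. The hyper-parameter
# values (warmup_steps=100, total_steps=1000, step_size=30) are made up for
# the example and are not FederatedScope defaults.
# ---------------------------------------------------------------------------
if __name__ == '__main__' and torch is not None:
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.SGD(params, lr=0.1)

    # Built-in linear warmup / linear decay schedule defined above.
    warmup = get_scheduler(optimizer,
                           'warmup_step',
                           warmup_steps=100,
                           total_steps=1000)

    # Any class of torch.optim.lr_scheduler can be addressed by its name.
    step_lr = get_scheduler(optimizer, 'StepLR', step_size=30, gamma=0.1)
    print('Built schedulers:', warmup, step_lr)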