41 lines
1.0 KiB
Python
41 lines
1.0 KiB
Python
from federatedscope.register import regularizer_dict
|
|
from federatedscope.core.regularizer.proximal_regularizer import *
|
|
try:
|
|
from torch.nn import Module
|
|
except ImportError:
|
|
Module = object
|
|
|
|
|
|
def get_regularizer(reg_type):
    """
    Build the regularizer instance used to regularize training.

    Args:
        reg_type: identifier of the requested regularizer. ``None`` or \
            the empty string selects the no-op ``DummyRegularizer``.

    Returns:
        An instantiated regularizer.

    Raises:
        NotImplementedError: if no builder registered in \
            ``regularizer_dict`` returns a non-``None`` result for \
            ``reg_type``.
    """
    nothing_requested = reg_type is None or reg_type == ''
    if nothing_requested:
        return DummyRegularizer()

    # Query every registered builder in turn; the first one that answers
    # with something other than None wins. The answer is itself called with
    # no arguments to produce the regularizer instance.
    for builder in regularizer_dict.values():
        regularizer_factory = builder(reg_type)
        if regularizer_factory is None:
            continue
        return regularizer_factory()

    message = "Regularizer {} is not implemented.".format(reg_type)
    raise NotImplementedError(message)
|
|
|
|
|
|
class DummyRegularizer(Module):
    """No-op regularizer whose ``forward`` always yields zero.

    ``get_regularizer`` returns an instance of this class when no
    regularizer type is specified.
    """
    def __init__(self):
        super().__init__()

    def forward(self, ctx):
        """Return the constant ``0.0``.

        Args:
            ctx: accepted for interface compatibility; not read here.

        Returns:
            The float ``0.0``.
        """
        return 0.0
|