47 lines
1.3 KiB
Python
47 lines
1.3 KiB
Python
import logging
|
|
import federatedscope.register as register
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# torch is an optional dependency: when it cannot be imported, ``nn`` is set
# to ``None`` so that the rest of this module can still be imported.
try:
    from torch import nn
except ImportError:
    nn = None

if nn is not None:
    # These packages are only imported once torch itself is available.
    # They are guarded separately from the torch import so that an
    # ImportError raised here does not reset ``nn`` to ``None`` (which would
    # make every built-in ``torch.nn`` criterion unavailable as well).
    # NOTE(review): presumably these wildcard imports register additional
    # criteria as an import side effect -- confirm against those packages.
    try:
        from federatedscope.nlp.loss import *
        from federatedscope.cl.loss import *
    except ImportError as error:
        logger.warning(
            f'{error} in `federatedscope.nlp.loss` or '
            f'`federatedscope.cl.loss`, some modules are not available.')
|
|
|
|
# Pull in everything exported by ``federatedscope.contrib.loss``. A failing
# import is not fatal: it is reported through the module logger and the rest
# of this module keeps loading.
try:
    from federatedscope.contrib.loss import *
except ImportError as error:
    logger.warning(f'{error} in `federatedscope.contrib.loss`, '
                   f'some modules are not available.')
|
|
|
|
|
|
def get_criterion(criterion_type, device):
    """
    This function builds an instance of loss functions.

    The builders stored in ``register.criterion_dict`` are tried first, in
    the dict's iteration order, and the first non-``None`` result is
    returned. If none of them handles ``criterion_type``, it is looked up by
    name in ``torch.nn`` (see
    "https://pytorch.org/docs/stable/nn.html#loss-functions") and
    instantiated without arguments.

    Arguments:
        criterion_type: loss function type; must be a ``str`` naming a \
            ``torch.nn`` loss unless a registered builder handles it
        device: device (``cpu`` or ``gpu``) handed to the registered \
            builders; it is not used by the ``torch.nn`` fallback

    Returns:
        An instance of loss functions.

    Raises:
        TypeError: ``criterion_type`` is not a ``str`` and no registered \
            builder handled it.
        NotImplementedError: no registered builder handled \
            ``criterion_type`` and it is not a callable attribute of \
            ``torch.nn`` (or torch could not be imported).
    """
    # Registered (user/contrib) criteria take precedence over torch.nn.
    for func in register.criterion_dict.values():
        criterion = func(criterion_type, device)
        if criterion is not None:
            return criterion

    if not isinstance(criterion_type, str):
        raise TypeError(
            'criterion_type should be a str, but got {}'.format(
                type(criterion_type).__name__))

    if nn is None:
        # The module-level torch import failed, so there is no built-in
        # loss to fall back on; say so instead of a bare "not implement".
        raise NotImplementedError(
            'Criterion {} not implement (torch is not available)'.format(
                criterion_type))

    criterion_cls = getattr(nn, criterion_type, None)
    # ``torch.nn`` also exposes non-callable attributes (e.g. the
    # ``functional`` submodule); those cannot be instantiated as a loss.
    if criterion_cls is None or not callable(criterion_cls):
        raise NotImplementedError(
            'Criterion {} not implement'.format(criterion_type))
    return criterion_cls()
|