import logging
import federatedscope.register as register

logger = logging.getLogger(__name__)

# torch is optional: without it only registered criterion builders can be used.
try:
    from torch import nn
except ImportError:
    nn = None
else:
    # Each project loss package is imported on its own so that a failure in
    # one of them neither discards the working ``torch.nn`` handle nor
    # prevents the other package from being imported.
    try:
        from federatedscope.nlp.loss import *  # noqa: F401,F403
    except ImportError as error:
        logger.warning(
            f'{error} in `federatedscope.nlp.loss`, some modules are not '
            f'available.')
    try:
        from federatedscope.cl.loss import *  # noqa: F401,F403
    except ImportError as error:
        logger.warning(
            f'{error} in `federatedscope.cl.loss`, some modules are not '
            f'available.')

try:
    from federatedscope.contrib.loss import *  # noqa: F401,F403
except ImportError as error:
    logger.warning(
        f'{error} in `federatedscope.contrib.loss`, some modules are not '
        f'available.')


def get_criterion(criterion_type, device):
    """
    Build an instance of a loss function (criterion).

    Every builder stored in ``register.criterion_dict`` is called first with
    ``(criterion_type, device)``; the first result that is not ``None`` is
    returned. If no builder handles the request and ``criterion_type`` is a
    string, it is looked up as an attribute name of ``torch.nn`` (see
    "https://pytorch.org/docs/stable/nn.html#loss-functions") and the
    attribute is called without arguments.

    Arguments:
        criterion_type: loss function type; for the ``torch.nn`` fallback \
            this must be the name of an attribute of ``torch.nn``
        device: device (``cpu`` or ``gpu``) forwarded to the registered \
            builders; the ``torch.nn`` fallback does not use it

    Returns:
        An instance of loss functions.

    Raises:
        NotImplementedError: ``criterion_type`` is a string that no \
            registered builder handles and that is not available from \
            ``torch.nn`` (including the case where torch is not installed).
        TypeError: ``criterion_type`` is not a string and no registered \
            builder handles it.
    """
    # Registered builders take precedence over the built-in torch lookup.
    for func in register.criterion_dict.values():
        criterion = func(criterion_type, device)
        if criterion is not None:
            return criterion

    if not isinstance(criterion_type, str):
        raise TypeError(
            'criterion_type must be a str when no registered criterion '
            'handles it, got {}'.format(type(criterion_type).__name__))

    # ``nn`` is None when torch could not be imported; guard explicitly so
    # the lookup is never performed on ``None`` itself.
    if nn is not None and hasattr(nn, criterion_type):
        return getattr(nn, criterion_type)()

    raise NotImplementedError(
        'Criterion {} not implement'.format(criterion_type))