61 lines
2.7 KiB
Python
61 lines
2.7 KiB
Python
import logging
|
|
import federatedscope.register as register
|
|
from federatedscope.nlp.hetero_tasks.metric import *
|
|
|
|
logger = logging.getLogger(__name__)

# Optional user-contributed metrics. The star import pulls every public name
# of the contrib package into this module (presumably so its metric builders
# register themselves as a side effect — confirm against
# federatedscope.contrib.metrics). If the package, or anything it imports,
# is missing, only a warning is emitted instead of failing the import.
try:
    from federatedscope.contrib.metrics import *
except ImportError as import_err:
    logger.warning(
        f'{import_err} in `federatedscope.contrib.metrics`, '
        'some modules are not available.')
|
|
|
|
|
|
def get_metric(types):
    """
    This function returns a dict, where the key is the metric name, and the \
    value is a pair of the function that calculates the metric and a bool \
    indicating whether a larger value of the metric is better.

    Args:
        types: list of metric names. A single metric name passed as a \
            plain ``str`` is treated as a one-element list.

    Returns:
        A metric calculator dict, such as \
        ``{'loss': (eval_loss, False), 'acc': (eval_acc, True), ...}``. \
        Requested names that no registered metric builder produces are \
        omitted from the dict and reported with a warning.

    Note:
        The key-value pairs of built-in metrics, their related funcs and \
        the ``the_larger_the_better`` sign are shown below:

        ================= ============================================= =====
        Metric name       Source                                        \
        The larger the better
        ================= ============================================= =====
        ``loss``          ``monitors.metric_calculator.eval_loss``      False
        ``avg_loss``      ``monitors.metric_calculator.eval_avg_loss``  False
        ``total``         ``monitors.metric_calculator.eval_total``     False
        ``correct``       ``monitors.metric_calculator.eval_correct``   True
        ``acc``           ``monitors.metric_calculator.eval_acc``       True
        ``ap``            ``monitors.metric_calculator.eval_ap``        True
        ``f1``            ``monitors.metric_calculator.eval_f1_score``  True
        ``roc_auc``       ``monitors.metric_calculator.eval_roc_auc``   True
        ``rmse``          ``monitors.metric_calculator.eval_rmse``      False
        ``mse``           ``monitors.metric_calculator.eval_mse``       False
        ``loss_regular``  ``monitors.metric_calculator.eval_regular``   False
        ``imp_ratio``     ``monitors.metric_calculator.eval_imp_ratio`` True
        ``std``           ``None``                                      False
        ``hits@{n}``      ``monitors.metric_calculator.eval_hits``      True
        ================= ============================================= =====
    """
    # A bare string would otherwise be iterated character by character below
    # (one spurious warning per character) and handed to every builder, where
    # a membership test such as ``'mse' in 'rmse'`` matches substrings.
    if isinstance(types, str):
        types = [types]

    metrics = {}
    # Each registered builder inspects ``types`` and returns either None or a
    # ``(name, metric_func, the_larger_the_better)`` triple; a later builder
    # returning an already-seen name overwrites the earlier entry.
    for func in register.metric_dict.values():
        res = func(types)
        if res is not None:
            name, metric, the_larger_the_better = res
            metrics[name] = metric, the_larger_the_better

    # Unknown metric names are not fatal: they are dropped with a warning.
    for key in types:
        if key not in metrics:
            logger.warning(f'eval.metrics `{key}` method not found!')

    return metrics
|