FS-TFP/federatedscope/nlp/hetero_tasks/metric/newsqa.py

26 lines
928 B
Python

from federatedscope.register import register_metric
from federatedscope.nlp.hetero_tasks.metric.squad import compute_squad_metrics
def load_newsqa_metrics(ctx, **kwargs):
    """Compute NewsQA evaluation metrics for the current data split.

    Collects the examples, encoded inputs and model results stored on
    ``ctx`` for ``ctx.cur_split``. It forwards them, with the
    answer-decoding settings from ``ctx.cfg.model``, to
    ``compute_squad_metrics``, so NewsQA is scored with the SQuAD routine.

    Args:
        ctx: Trainer context. It must provide ``cur_split``, a ``get()``
            lookup for ``'<split>_examples'`` / ``'<split>_encoded'``, and
            ``newsqa_results``. It must also provide ``cfg.model`` with
            ``n_best_size``, ``max_answer_len`` and
            ``null_score_diff_threshold``.
        **kwargs: Unused here. Presumably accepted so the signature matches
            other registered metric callables. TODO confirm against the
            metric calculator.

    Returns:
        Whatever ``compute_squad_metrics`` returns. Its type is defined in
        ``federatedscope.nlp.hetero_tasks.metric.squad``, not here.
    """
    # Read the split name once; both per-split lookups below key off it.
    split = ctx.cur_split
    examples = ctx.get('{}_examples'.format(split))
    encoded_inputs = ctx.get('{}_encoded'.format(split))
    results = ctx.newsqa_results

    model_cfg = ctx.cfg.model
    return compute_squad_metrics(examples, encoded_inputs, results,
                                 model_cfg.n_best_size,
                                 model_cfg.max_answer_len,
                                 model_cfg.null_score_diff_threshold)
def call_newsqa_metric(types):
    """Return the NewsQA metric entry if ``'newsqa'`` is requested.

    Args:
        types: Container of requested metric names. It is tested with the
            ``in`` operator; presumably a list of names from the evaluation
            config (TODO confirm with the caller of the registry).

    Returns:
        ``('newsqa', load_newsqa_metrics, True)`` when ``'newsqa'`` is in
        ``types``. The trailing ``True`` marks that larger metric values
        are better. Returns ``None`` otherwise.
    """
    # Guard clause: make the "not requested" result explicit instead of
    # falling off the end of the function.
    if 'newsqa' not in types:
        return None
    the_larger_the_better = True
    return 'newsqa', load_newsqa_metrics, the_larger_the_better
# Register call_newsqa_metric with FederatedScope's metric registry under
# the key 'newsqa'. That function decides, from the requested metric names,
# whether the NewsQA metric applies.
register_metric('newsqa', call_newsqa_metric)