26 lines
928 B
Python
26 lines
928 B
Python
from federatedscope.register import register_metric
|
|
from federatedscope.nlp.hetero_tasks.metric.squad import compute_squad_metrics
|
|
|
|
|
|
def load_newsqa_metrics(ctx, **kwargs):
|
|
examples = ctx.get('{}_examples'.format(ctx.cur_split))
|
|
encoded_inputs = ctx.get('{}_encoded'.format(ctx.cur_split))
|
|
results = ctx.newsqa_results
|
|
n_best_size = ctx.cfg.model.n_best_size
|
|
max_answer_len = ctx.cfg.model.max_answer_len
|
|
null_score_diff_threshold = ctx.cfg.model.null_score_diff_threshold
|
|
|
|
metrics = compute_squad_metrics(examples, encoded_inputs, results,
|
|
n_best_size, max_answer_len,
|
|
null_score_diff_threshold)
|
|
return metrics
|
|
|
|
|
|
def call_newsqa_metric(types):
|
|
if 'newsqa' in types:
|
|
the_larger_the_better = True
|
|
return 'newsqa', load_newsqa_metrics, the_larger_the_better
|
|
|
|
|
|
register_metric('newsqa', call_newsqa_metric)
|