from federatedscope.register import register_metric
from federatedscope.nlp.hetero_tasks.metric.squad import compute_squad_metrics


def load_newsqa_metrics(ctx, **kwargs):
    """Compute the NewsQA metrics for the split currently held in ``ctx``.

    The function looks up ``'<cur_split>_examples'`` and
    ``'<cur_split>_encoded'`` through ``ctx.get``, reads the predictions
    from ``ctx.newsqa_results`` and the decoding options
    (``n_best_size``, ``max_answer_len``, ``null_score_diff_threshold``)
    from ``ctx.cfg.model``, and hands all of them to
    ``compute_squad_metrics``.

    Args:
        ctx: Context object exposing ``cur_split``, ``get``,
            ``newsqa_results`` and ``cfg.model``.
        **kwargs: Accepted for signature compatibility; not used here.

    Returns:
        Whatever ``compute_squad_metrics`` returns for these inputs.
    """
    split_examples = ctx.get('{}_examples'.format(ctx.cur_split))
    split_encoded = ctx.get('{}_encoded'.format(ctx.cur_split))
    predictions = ctx.newsqa_results

    # Decoding options are read straight from the model config, in the same
    # order in which they are passed on.
    return compute_squad_metrics(
        split_examples,
        split_encoded,
        predictions,
        ctx.cfg.model.n_best_size,
        ctx.cfg.model.max_answer_len,
        ctx.cfg.model.null_score_diff_threshold,
    )


def call_newsqa_metric(types):
    """Return the NewsQA metric entry when ``'newsqa'`` is requested.

    Args:
        types: Container of requested metric names.

    Returns:
        A ``(name, metric_function, the_larger_the_better)`` tuple if
        ``'newsqa'`` is in ``types``; otherwise ``None``.
    """
    if 'newsqa' not in types:
        return None

    # Third element is the "the larger the better" flag.
    return 'newsqa', load_newsqa_metrics, True


register_metric('newsqa', call_newsqa_metric)