27 lines
699 B
Python
27 lines
699 B
Python
import logging
|
|
import os
|
|
from federatedscope.register import register_metric
|
|
from federatedscope.nlp.metric.rouge.utils import test_rouge
|
|
|
|
# Module-level logger named after this module. It is not referenced by the
# functions defined in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def load_cnndm_metrics(ctx, **kwargs):
    """Compute the ROUGE F-scores reported by the ``'cnndm'`` metric.

    Runs ``test_rouge`` on the predictions at ``ctx.pred_path`` against the
    references at ``ctx.tgt_path``. A ``temp`` subdirectory of
    ``ctx.cfg.outdir`` is passed as its working directory. Only the ROUGE-1,
    ROUGE-2 and ROUGE-L F-score entries of the result are kept.

    Args:
        ctx: Trainer context. Reads ``ctx.cfg.outdir``, ``ctx.pred_path``
            and ``ctx.tgt_path``.
        **kwargs: Accepted and ignored.

    Returns:
        dict: The subset of ``test_rouge``'s result whose keys are
        ``'rouge_1_f_score'``, ``'rouge_2_f_score'`` or
        ``'rouge_l_f_score'``, in the order ``test_rouge`` produced them.
        A key missing from that result is simply absent.
    """
    kept_keys = frozenset(
        ('rouge_1_f_score', 'rouge_2_f_score', 'rouge_l_f_score'))
    scratch_dir = os.path.join(ctx.cfg.outdir, 'temp')
    # NOTE(review): assumes test_rouge returns a mapping (``.items()`` is
    # used below) -- confirm against its definition.
    all_scores = test_rouge(scratch_dir, ctx.pred_path, ctx.tgt_path)
    # Iterate the scores (not kept_keys) so the output keeps test_rouge's
    # key order.
    return {
        name: score
        for name, score in all_scores.items() if name in kept_keys
    }
|
|
|
|
|
|
def call_cnndm_metric(types):
    """Metric-registry hook for the ``'cnndm'`` metric.

    Args:
        types: Collection of requested metric names, tested with ``in``.

    Returns:
        ``('cnndm', load_cnndm_metrics, True)`` when ``'cnndm'`` is among
        *types*, otherwise ``None``. The last element is the
        "larger is better" flag.
    """
    if 'cnndm' not in types:
        return None
    # ROUGE F-scores improve as they grow, so larger is better.
    return 'cnndm', load_cnndm_metrics, True
|
|
|
|
|
|
# Register call_cnndm_metric with the FederatedScope metric registry under
# the name 'cnndm'. This runs once, at module import time.
register_metric('cnndm', call_cnndm_metric)
|