""" The implementations are adapted from https://github.com/hugochan/ RL-based-Graph2Seq-for-NQG/blob/master/src/core/evaluation/eval.py and https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py """ from json import encoder from collections import defaultdict from federatedscope.nlp.metric.bleu import Bleu from federatedscope.nlp.metric.meteor import Meteor encoder.FLOAT_REPR = lambda o: format(o, '.4f') class QGEvalCap: def __init__(self, gts, res): self.gts = gts self.res = res def evaluate(self, include_meteor=True, verbose=False): output = {} scorers = [ (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), ] if include_meteor: scorers.append((Meteor(), "METEOR")) for scorer, method in scorers: score, scores = scorer.compute_score(self.gts, self.res) if type(method) == list: for sc, scs, m in zip(score, scores, method): if verbose: print("%s: %0.5f" % (m, sc)) output[m] = sc else: if verbose: print("%s: %0.5f" % (method, score)) output[method] = score return output def eval(out_file, src_file, tgt_file): pairs = [] with open(src_file, 'r', encoding='utf-8') as infile: for line in infile: pair = {} pair['tokenized_sentence'] = line[:-1] pairs.append(pair) with open(tgt_file, 'r', encoding='utf-8') as infile: cnt = 0 for line in infile: pairs[cnt]['tokenized_question'] = line[:-1] cnt += 1 output = [] with open(out_file, 'r', encoding='utf-8') as infile: for line in infile: line = line[:-1] output.append(line) for idx, pair in enumerate(pairs): pair['prediction'] = output[idx] res = defaultdict(lambda: []) gts = defaultdict(lambda: []) for pair in pairs[:]: key = pair['tokenized_sentence'] res[key] = [pair['prediction']] gts[key].append(pair['tokenized_question']) QGEval = QGEvalCap(gts, res) return QGEval.evaluate() if __name__ == "__main__": from argparse import ArgumentParser parser = ArgumentParser() parser.add_argument("-out", "--out_file", dest="out_file", default="./output/pred.txt", help="output file to compare") parser.add_argument("-src", "--src_file", dest="src_file", default="../data/processed/src-test.txt", help="src file") parser.add_argument("-tgt", "--tgt_file", dest="tgt_file", default="../data/processed/tgt-test.txt", help="target file") args = parser.parse_args() print("scores: \n") eval(args.out_file, args.src_file, args.tgt_file)