"""
|
|
The implementations are adapted from https://github.com/hugochan/
|
|
RL-based-Graph2Seq-for-NQG/blob/master/src/core/evaluation/eval.py and
|
|
https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py
|
|
"""
|
|
|
|
from json import encoder
|
|
from collections import defaultdict
|
|
from federatedscope.nlp.metric.bleu import Bleu
|
|
from federatedscope.nlp.metric.meteor import Meteor
|
|
|
|
encoder.FLOAT_REPR = lambda o: format(o, '.4f')
|
|
|
|
|
|
class QGEvalCap:
|
|
def __init__(self, gts, res):
|
|
self.gts = gts
|
|
self.res = res
|
|
|
|
def evaluate(self, include_meteor=True, verbose=False):
|
|
output = {}
|
|
scorers = [
|
|
(Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
|
|
]
|
|
if include_meteor:
|
|
scorers.append((Meteor(), "METEOR"))
|
|
|
|
for scorer, method in scorers:
|
|
score, scores = scorer.compute_score(self.gts, self.res)
|
|
if type(method) == list:
|
|
for sc, scs, m in zip(score, scores, method):
|
|
if verbose:
|
|
print("%s: %0.5f" % (m, sc))
|
|
output[m] = sc
|
|
else:
|
|
if verbose:
|
|
print("%s: %0.5f" % (method, score))
|
|
output[method] = score
|
|
return output
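
# A minimal usage sketch (illustrative data, not taken from the repository):
# both dicts are keyed by the tokenized source sentence; `gts` carries one
# or more reference questions and `res` exactly one prediction per key.
#
#     gts = {'the cat sat on the mat': ['where did the cat sit ?']}
#     res = {'the cat sat on the mat': ['where was the cat ?']}
#     scores = QGEvalCap(gts, res).evaluate(include_meteor=False)
#     # -> {'Bleu_1': ..., 'Bleu_2': ..., 'Bleu_3': ..., 'Bleu_4': ...}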


def eval(out_file, src_file, tgt_file):
    """Reads parallel files of predictions (`out_file`), tokenized source
    sentences (`src_file`) and reference questions (`tgt_file`), one
    example per line, and returns the corpus-level scores."""
    pairs = []
    with open(src_file, 'r', encoding='utf-8') as infile:
        for line in infile:
            pair = {}
            pair['tokenized_sentence'] = line.rstrip('\n')
            pairs.append(pair)

    with open(tgt_file, 'r', encoding='utf-8') as infile:
        cnt = 0
        for line in infile:
            pairs[cnt]['tokenized_question'] = line.rstrip('\n')
            cnt += 1

    output = []
    with open(out_file, 'r', encoding='utf-8') as infile:
        for line in infile:
            output.append(line.rstrip('\n'))

    for idx, pair in enumerate(pairs):
        pair['prediction'] = output[idx]

    # Group references and predictions by source sentence: a duplicated
    # source contributes multiple references, while only its last
    # prediction is kept.
    res = defaultdict(list)
    gts = defaultdict(list)
    for pair in pairs:
        key = pair['tokenized_sentence']
        res[key] = [pair['prediction']]
        gts[key].append(pair['tokenized_question'])

    QGEval = QGEvalCap(gts, res)
    return QGEval.evaluate()
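
# Expected inputs, as read above (contents are illustrative; the paths are
# the argparse defaults below): the three files are parallel, with one
# tokenized example per line.
#
#     src-test.txt : the cat sat on the mat
#     tgt-test.txt : where did the cat sit ?
#     pred.txt     : where was the cat ?
#
#     scores = eval('./output/pred.txt',
#                   '../data/processed/src-test.txt',
#                   '../data/processed/tgt-test.txt')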


if __name__ == "__main__":
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument("-out",
                        "--out_file",
                        dest="out_file",
                        default="./output/pred.txt",
                        help="output file to compare")
    parser.add_argument("-src",
                        "--src_file",
                        dest="src_file",
                        default="../data/processed/src-test.txt",
                        help="src file")
    parser.add_argument("-tgt",
                        "--tgt_file",
                        dest="tgt_file",
                        default="../data/processed/tgt-test.txt",
                        help="target file")
    args = parser.parse_args()

    print("scores:")
    # Print the returned dict; evaluate() is silent unless verbose=True.
    print(eval(args.out_file, args.src_file, args.tgt_file))