# FS-TFP/federatedscope/nlp/hetero_tasks/metric/squad.py

import math
import collections
import string
import re
import logging

from federatedscope.register import register_metric

logger = logging.getLogger(__name__)


def normalize_answer(s):
    '''Lower text and remove punctuation, articles and extra whitespace.'''
    def remove_articles(text):
        regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
        return re.sub(regex, ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def get_tokens(s):
    if not s:
        return []
    return normalize_answer(s).split()


def compute_exact(a_gold, a_pred):
    return int(normalize_answer(a_gold) == normalize_answer(a_pred))


def compute_f1(a_gold, a_pred):
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if len(gold_toks) == 0 or len(pred_toks) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return int(gold_toks == pred_toks)
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
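

# Illustrative sanity check (these literals are not from the original file):
# with a_gold = 'the cat sat on the mat' and a_pred = 'cat sat', normalization
# drops the articles, so the gold tokens are ['cat', 'sat', 'on', 'mat'] and
# the predicted tokens are ['cat', 'sat']; precision = 2/2, recall = 2/4, and
# compute_f1 returns 2 * 1.0 * 0.5 / 1.5 = 2/3. compute_exact on the same pair
# returns 0 because the normalized strings differ.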


def get_raw_scores(examples, preds):
    '''Compute the exact and F1 scores from the examples and the model
    predictions.'''
    exact_scores = {}
    f1_scores = {}
    for example in examples:
        qa_id = example.qa_id
        gold_answers = [
            answer['text'] for answer in example.val_answer
            if normalize_answer(answer['text'])
        ]
        if not gold_answers:
            # For unanswerable questions, the only correct answer is the
            # empty string
            gold_answers = ['']
        if qa_id not in preds:
            logger.warning('Missing prediction for %s', qa_id)
            continue
        prediction = preds[qa_id]
        exact_scores[qa_id] = max(
            compute_exact(a, prediction) for a in gold_answers)
        f1_scores[qa_id] = max(compute_f1(a, prediction) for a in gold_answers)
    return exact_scores, f1_scores


def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
    new_scores = {}
    for qid, s in scores.items():
        pred_na = na_probs[qid] > na_prob_thresh
        if pred_na:
            new_scores[qid] = float(not qid_to_has_ans[qid])
        else:
            new_scores[qid] = s
    return new_scores
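

# Illustrative behaviour (example values are not from the original file): with
# na_prob_thresh = 0.5, a question whose null probability is 0.9 is treated as
# predicted-unanswerable, so its score becomes 1.0 only if the gold label is
# also unanswerable (qid_to_has_ans[qid] is False) and 0.0 otherwise; the raw
# span score is kept only when the null probability stays at or below 0.5.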


def make_eval_dict(exact_scores, f1_scores, qid_list=None):
    if not qid_list:
        qid_list = list(exact_scores)
    total = len(qid_list)
    exact = 100.0 * sum(exact_scores[k] for k in qid_list) / total
    f1 = 100.0 * sum(f1_scores[k] for k in qid_list) / total
    return collections.OrderedDict([
        ('exact', exact),
        ('f1', f1),
        ('exact_and_f1', (exact + f1) / 2),
        ('total', total),
    ])


def merge_eval(main_eval, new_eval, prefix):
    for k in new_eval:
        main_eval['%s_%s' % (prefix, k)] = new_eval[k]


def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
    num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
    cur_score = num_no_ans
    best_score = cur_score
    best_thresh = 0.0
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])
    for qid in qid_list:
        if qid not in scores:
            continue
        if qid_to_has_ans[qid]:
            diff = scores[qid]
        else:
            if preds[qid]:
                diff = -1
            else:
                diff = 0
        cur_score += diff
        if cur_score > best_score:
            best_score = cur_score
            best_thresh = na_probs[qid]
    return 100.0 * best_score / len(scores), best_thresh
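

# How the sweep above works (descriptive note, not from the original file): the
# starting score assumes every question is answered with the empty string,
# which is correct exactly for the unanswerable ones. Walking the questions in
# order of increasing null probability then simulates raising the no-answer
# threshold one question at a time, so that question's actual prediction starts
# to count: an answerable question contributes its span score, while an
# unanswerable one loses a point if the model produced a non-empty prediction
# for it. The threshold at which the running total peaks is returned alongside
# the corresponding score.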


def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs,
                         qid_to_has_ans):
    best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs,
                                                qid_to_has_ans)
    best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs,
                                          qid_to_has_ans)
    main_eval['best_exact'] = best_exact
    main_eval['best_exact_thresh'] = exact_thresh
    main_eval['best_f1'] = best_f1
    main_eval['best_f1_thresh'] = f1_thresh


def squad_evaluate(examples,
                   preds,
                   no_answer_probs=None,
                   no_answer_probability_threshold=1.0):
    qa_id_to_has_answer = {
        example.qa_id: bool(example.val_answer)
        for example in examples
    }
    has_answer_qids = [
        qa_id for qa_id, has_answer in qa_id_to_has_answer.items()
        if has_answer
    ]
    no_answer_qids = [
        qa_id for qa_id, has_answer in qa_id_to_has_answer.items()
        if not has_answer
    ]
    if no_answer_probs is None:
        no_answer_probs = {k: 0.0 for k in preds}

    exact, f1 = get_raw_scores(examples, preds)
    exact_threshold = apply_no_ans_threshold(exact, no_answer_probs,
                                             qa_id_to_has_answer,
                                             no_answer_probability_threshold)
    f1_threshold = apply_no_ans_threshold(f1, no_answer_probs,
                                          qa_id_to_has_answer,
                                          no_answer_probability_threshold)
    evaluation = make_eval_dict(exact_threshold, f1_threshold)

    if has_answer_qids:
        has_ans_eval = make_eval_dict(exact_threshold,
                                      f1_threshold,
                                      qid_list=has_answer_qids)
        merge_eval(evaluation, has_ans_eval, 'HasAns')
    if no_answer_qids:
        no_ans_eval = make_eval_dict(exact_threshold,
                                     f1_threshold,
                                     qid_list=no_answer_qids)
        merge_eval(evaluation, no_ans_eval, 'NoAns')
    if no_answer_probs:
        find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs,
                             qa_id_to_has_answer)
    return evaluation
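

# Shape of the result (descriptive note): the returned OrderedDict always
# holds 'exact', 'f1', 'exact_and_f1' and 'total'; the same four keys are
# repeated with a 'HasAns_' / 'NoAns_' prefix when the split contains
# answerable / unanswerable questions, and 'best_exact', 'best_exact_thresh',
# 'best_f1' and 'best_f1_thresh' are added whenever null-odds are available.
# A runnable toy example is sketched in the __main__ block at the bottom of
# this file.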


def get_final_text(pred_text, orig_text):
    '''Project the tokenized prediction back to the original text.'''
    # When we created the data, we kept track of the alignment between original
    # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
    # now `orig_text` contains the span of our original text corresponding
    # to the span that we predicted.
    #
    # However, `orig_text` may contain extra characters that we don't want in
    # our prediction.
    #
    # For example, let's say:
    #   pred_text = steve smith
    #   orig_text = Steve Smith's
    #
    # We don't want to return `orig_text` because it contains the extra "'s".
    #
    # We don't want to return `pred_text` because it has already been
    # normalized (the SQuAD eval script also does punctuation stripping and
    # lower casing, but our tokenizer does additional normalization like
    # stripping accent characters).
    #
    # What we really want to return is 'Steve Smith'.
    #
    # Therefore, we have to apply a semi-complicated alignment heuristic
    # between `pred_text` and `orig_text` to get a character-to-character
    # alignment. This can fail in certain cases, in which case we just return
    # `orig_text`.
    from transformers import BasicTokenizer

    def _strip_spaces(text):
        ns_chars = []
        ns_to_s_map = collections.OrderedDict()
        for (i, c) in enumerate(text):
            if c == ' ':
                continue
            ns_to_s_map[len(ns_chars)] = i
            ns_chars.append(c)
        ns_text = ''.join(ns_chars)
        return (ns_text, ns_to_s_map)

    # We first tokenize `orig_text`, strip whitespace from the result
    # and `pred_text`, and check if they are the same length. If they are
    # NOT the same length, the heuristic has failed. If they are the same
    # length, we assume the characters are one-to-one aligned.
    tokenizer = BasicTokenizer()
    tok_text = ' '.join(tokenizer.tokenize(orig_text))
    start_position = tok_text.find(pred_text)
    if start_position == -1:
        return orig_text
    end_position = start_position + len(pred_text) - 1

    (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
    (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
    if len(orig_ns_text) != len(tok_ns_text):
        return orig_text

    # We then project the characters in `pred_text` back to `orig_text` using
    # the character-to-character alignment.
    tok_s_to_ns_map = {}
    for (i, tok_index) in tok_ns_to_s_map.items():
        tok_s_to_ns_map[tok_index] = i

    orig_start_position = None
    if start_position in tok_s_to_ns_map:
        ns_start_position = tok_s_to_ns_map[start_position]
        if ns_start_position in orig_ns_to_s_map:
            orig_start_position = orig_ns_to_s_map[ns_start_position]
    if orig_start_position is None:
        return orig_text

    orig_end_position = None
    if end_position in tok_s_to_ns_map:
        ns_end_position = tok_s_to_ns_map[end_position]
        if ns_end_position in orig_ns_to_s_map:
            orig_end_position = orig_ns_to_s_map[ns_end_position]
    if orig_end_position is None:
        return orig_text

    output_text = orig_text[orig_start_position:(orig_end_position + 1)]
    return output_text
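

# Worked example of the projection above (assuming transformers'
# BasicTokenizer with its default lowercasing): for pred_text = 'steve smith'
# and orig_text = "Steve Smith's", the tokenized original becomes
# "steve smith ' s"; stripping spaces from both sides yields strings of equal
# length, the character alignment maps the predicted span onto the first
# eleven original characters, and the function returns 'Steve Smith' rather
# than either input.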


def get_topk_indices(logits, n_best_size):
    index_and_score = sorted(enumerate(logits),
                             key=lambda x: x[1],
                             reverse=True)
    topk_indices = []
    for i in range(len(index_and_score)):
        if i >= n_best_size:
            break
        topk_indices.append(index_and_score[i][0])
    return topk_indices


def _compute_softmax(scores):
    '''Compute softmax probability over raw logits.'''
    if not scores:
        return []

    max_score = None
    for score in scores:
        if max_score is None or score > max_score:
            max_score = score

    exp_scores = []
    total_sum = 0.0
    for score in scores:
        x = math.exp(score - max_score)
        exp_scores.append(x)
        total_sum += x

    probs = []
    for score in exp_scores:
        probs.append(score / total_sum)
    return probs
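

# Illustrative check (values are not from the original file):
# _compute_softmax([2.0, 1.0, 0.0]) subtracts the maximum logit before
# exponentiating, which keeps the computation numerically stable, and returns
# approximately [0.665, 0.245, 0.090], summing to 1.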


def create_squad_answer_texts(examples, encoded_inputs, results, n_best_size,
                              max_answer_len, null_score_diff_threshold):
    _PrelimPrediction = collections.namedtuple('PrelimPrediction', [
        'feature_index', 'start_index', 'end_index', 'start_logit', 'end_logit'
    ])
    _NbestPrediction = collections.namedtuple(
        'NbestPrediction', ['text', 'start_logit', 'end_logit'])

    example_index_to_features = collections.defaultdict(list)
    for feature in encoded_inputs:
        example_index_to_features[feature.example_index].append(feature)

    unique_id_to_result = {}
    for result in results:
        unique_id_to_result[result.unique_id] = result

    predicted_answer_texts = collections.OrderedDict()
    for (example_index, example) in enumerate(examples):
        features = example_index_to_features[example_index]
        prelim_predictions = []
        # keep track of the minimum score of null start+end of position 0
        score_null = 1000000  # large and positive
        # the paragraph slice with the minimum null score
        min_null_feature_index = 0
        # the start/end logits at the slice with the minimum null score
        null_start_logit = 0
        null_end_logit = 0
        for (feature_index, feature) in enumerate(features):
            result = unique_id_to_result[feature.unique_id]
            start_indexes = get_topk_indices(result.start_logits, n_best_size)
            end_indexes = get_topk_indices(result.end_logits, n_best_size)
            # if we could have irrelevant answers, get the min score of
            # irrelevant
            feature_null_score = result.start_logits[0] + result.end_logits[0]
            if feature_null_score < score_null:
                score_null = feature_null_score
                min_null_feature_index = feature_index
                null_start_logit = result.start_logits[0]
                null_end_logit = result.end_logits[0]
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # We could hypothetically create invalid predictions,
                    # e.g., predict that the start of the span is in the
                    # question. We throw out all invalid predictions.
                    if start_index >= len(feature.tokens):
                        continue
                    if end_index >= len(feature.tokens):
                        continue
                    if start_index not in feature.context_subtok_to_tok_idx:
                        continue
                    if end_index not in feature.context_subtok_to_tok_idx:
                        continue
                    if not feature.is_max_context_token.get(
                            start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_len:
                        continue
                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index]))
        prelim_predictions.append(
            _PrelimPrediction(feature_index=min_null_feature_index,
                              start_index=0,
                              end_index=0,
                              start_logit=null_start_logit,
                              end_logit=null_end_logit))
        prelim_predictions = sorted(prelim_predictions,
                                    key=lambda x: x.start_logit + x.end_logit,
                                    reverse=True)

        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]
            if pred.start_index > 0:  # this is a non-null prediction
                tok_tokens = \
                    feature.tokens[pred.start_index:(pred.end_index + 1)]
                orig_doc_start = \
                    feature.context_subtok_to_tok_idx[pred.start_index]
                orig_doc_end = \
                    feature.context_subtok_to_tok_idx[pred.end_index]
                orig_tokens = \
                    example.context_tokens[orig_doc_start:(orig_doc_end + 1)]
                tok_text = ' '.join(tok_tokens)
                # De-tokenize WordPieces that have been split off.
                tok_text = tok_text.replace(' ##', '')
                tok_text = tok_text.replace('##', '')
                # Clean whitespace
                tok_text = tok_text.strip()
                tok_text = ' '.join(tok_text.split())
                orig_text = ' '.join(orig_tokens)
                final_text = get_final_text(tok_text, orig_text)
                if final_text in seen_predictions:
                    continue
                seen_predictions[final_text] = True
            else:
                final_text = ''
                seen_predictions[final_text] = True
            nbest.append(
                _NbestPrediction(text=final_text,
                                 start_logit=pred.start_logit,
                                 end_logit=pred.end_logit))

        # if we didn't include the empty option in the n-best, include it
        if '' not in seen_predictions:
            nbest.append(
                _NbestPrediction(text='',
                                 start_logit=null_start_logit,
                                 end_logit=null_end_logit))
        # In very rare edge cases we could only have a single null prediction.
        # So we just create a nonce prediction in this case to avoid failure.
        if len(nbest) == 1:
            nbest.insert(
                0,
                _NbestPrediction(text='empty', start_logit=0.0, end_logit=0.0))
        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(
                _NbestPrediction(text='empty', start_logit=0.0, end_logit=0.0))

        total_scores = []
        best_non_null_entry = None
        for entry in nbest:
            total_scores.append(entry.start_logit + entry.end_logit)
            if not best_non_null_entry:
                if entry.text:
                    best_non_null_entry = entry

        score_diff = (score_null - best_non_null_entry.start_logit -
                      best_non_null_entry.end_logit)
        if score_diff > null_score_diff_threshold:
            predicted_answer_texts[example.qa_id] = ''
        else:
            predicted_answer_texts[example.qa_id] = best_non_null_entry.text

    return predicted_answer_texts
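

# Expected inputs, as used above (descriptive note): each item of `results`
# carries `unique_id`, `start_logits` and `end_logits`; each item of
# `encoded_inputs` carries `example_index`, `unique_id`, `tokens`,
# `context_subtok_to_tok_idx` and `is_max_context_token`; each example carries
# `qa_id` and `context_tokens`. The returned OrderedDict maps each qa_id to
# the predicted answer text, using '' whenever the null score beats the best
# non-null span by more than `null_score_diff_threshold`.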


def compute_squad_metrics(examples,
                          encoded_inputs,
                          results,
                          n_best_size,
                          max_answer_len,
                          null_score_diff_threshold=None,
                          return_text=False):
    predicted_answer_texts = create_squad_answer_texts(
        examples, encoded_inputs, results, n_best_size, max_answer_len,
        null_score_diff_threshold)
    raw_metrics = squad_evaluate(examples, predicted_answer_texts)
    metrics = {
        k: v
        for k, v in raw_metrics.items() if k in ('exact', 'f1', 'exact_and_f1')
    }
    if return_text:
        return predicted_answer_texts
    return metrics


def load_squad_metrics(ctx, **kwargs):
    examples = ctx.get('{}_examples'.format(ctx.cur_split))
    encoded_inputs = ctx.get('{}_encoded'.format(ctx.cur_split))
    results = ctx.squad_results
    n_best_size = ctx.cfg.model.n_best_size
    max_answer_len = ctx.cfg.model.max_answer_len
    null_score_diff_threshold = ctx.cfg.model.null_score_diff_threshold
    metrics = compute_squad_metrics(examples, encoded_inputs, results,
                                    n_best_size, max_answer_len,
                                    null_score_diff_threshold)
    return metrics


def call_squad_metric(types):
    if 'squad' in types:
        the_larger_the_better = True
        return 'squad', load_squad_metrics, the_larger_the_better


register_metric('squad', call_squad_metric)
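

# Minimal smoke test for the answer-level scoring path (a sketch, not part of
# the original module): `_ToyExample` is a hypothetical stand-in for the real
# example class and only provides the two attributes squad_evaluate relies on
# here (`qa_id` and `val_answer`). Running this file directly should report
# 100.0 exact match and F1 for the two toy questions below.
if __name__ == '__main__':
    _ToyExample = collections.namedtuple('ToyExample', ['qa_id', 'val_answer'])
    toy_examples = [
        # answerable question with a single gold span
        _ToyExample(qa_id='q1', val_answer=[{'text': 'Steve Smith'}]),
        # unanswerable question: the empty string is the only correct answer
        _ToyExample(qa_id='q2', val_answer=[]),
    ]
    toy_preds = {'q1': 'steve smith', 'q2': ''}
    for key, value in squad_evaluate(toy_examples, toy_preds).items():
        print('{}: {}'.format(key, value))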