import math
import collections
import string
import re
import logging

from federatedscope.register import register_metric

logger = logging.getLogger(__name__)


def normalize_answer(s):
    '''Lower text and remove punctuation, articles and extra whitespace.'''
    def remove_articles(text):
        regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
        return re.sub(regex, ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def get_tokens(s):
    if not s:
        return []
    return normalize_answer(s).split()


def compute_exact(a_gold, a_pred):
    return int(normalize_answer(a_gold) == normalize_answer(a_pred))


def compute_f1(a_gold, a_pred):
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if len(gold_toks) == 0 or len(pred_toks) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return int(gold_toks == pred_toks)
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def get_raw_scores(examples, preds):
    '''Computes the exact and F1 scores from the examples and the model
    predictions.'''
    exact_scores = {}
    f1_scores = {}

    for example in examples:
        qa_id = example.qa_id
        gold_answers = [
            answer['text'] for answer in example.val_answer
            if normalize_answer(answer['text'])
        ]

        if not gold_answers:
            # For unanswerable questions, the only correct answer is the
            # empty string
            gold_answers = ['']

        if qa_id not in preds:
            logger.warning('Missing prediction for %s', qa_id)
            continue

        prediction = preds[qa_id]
        exact_scores[qa_id] = max(
            compute_exact(a, prediction) for a in gold_answers)
        f1_scores[qa_id] = max(
            compute_f1(a, prediction) for a in gold_answers)

    return exact_scores, f1_scores


def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
    new_scores = {}
    for qid, s in scores.items():
        pred_na = na_probs[qid] > na_prob_thresh
        if pred_na:
            new_scores[qid] = float(not qid_to_has_ans[qid])
        else:
            new_scores[qid] = s
    return new_scores


def make_eval_dict(exact_scores, f1_scores, qid_list=None):
    # When no qid subset is given, evaluate over all scored questions.
    if not qid_list:
        qid_list = list(exact_scores)
    total = len(qid_list)
    exact = 100.0 * sum(exact_scores[k] for k in qid_list) / total
    f1 = 100.0 * sum(f1_scores[k] for k in qid_list) / total
    return collections.OrderedDict([
        ('exact', exact),
        ('f1', f1),
        ('exact_and_f1', (exact + f1) / 2),
        ('total', total),
    ])


def merge_eval(main_eval, new_eval, prefix):
    for k in new_eval:
        main_eval['%s_%s' % (prefix, k)] = new_eval[k]


def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
    num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
    cur_score = num_no_ans
    best_score = cur_score
    best_thresh = 0.0
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])
    for qid in qid_list:
        if qid not in scores:
            continue
        if qid_to_has_ans[qid]:
            diff = scores[qid]
        else:
            if preds[qid]:
                diff = -1
            else:
                diff = 0
        cur_score += diff
        if cur_score > best_score:
            best_score = cur_score
            best_thresh = na_probs[qid]
    return 100.0 * best_score / len(scores), best_thresh
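

# Illustrative behaviour of the scoring helpers above (values worked out by
# hand; shown as comments so nothing runs at import time):
#
#   normalize_answer("The Steve Smith's")        -> 'steve smiths'
#   compute_exact('Steve Smith', 'steve smith!') -> 1
#   compute_f1('Steve Smith', 'Mr. Steve Smith') -> 0.8
#       (precision = 2/3, recall = 2/2, so F1 = 2pr / (p + r) = 0.8)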


def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs,
                         qid_to_has_ans):
    best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs,
                                                qid_to_has_ans)
    best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs,
                                          qid_to_has_ans)

    main_eval['best_exact'] = best_exact
    main_eval['best_exact_thresh'] = exact_thresh
    main_eval['best_f1'] = best_f1
    main_eval['best_f1_thresh'] = f1_thresh


def squad_evaluate(examples,
                   preds,
                   no_answer_probs=None,
                   no_answer_probability_threshold=1.0):
    qa_id_to_has_answer = {
        example.qa_id: bool(example.val_answer)
        for example in examples
    }
    has_answer_qids = [
        qa_id for qa_id, has_answer in qa_id_to_has_answer.items()
        if has_answer
    ]
    no_answer_qids = [
        qa_id for qa_id, has_answer in qa_id_to_has_answer.items()
        if not has_answer
    ]

    if no_answer_probs is None:
        no_answer_probs = {k: 0.0 for k in preds}

    exact, f1 = get_raw_scores(examples, preds)

    exact_threshold = apply_no_ans_threshold(exact, no_answer_probs,
                                             qa_id_to_has_answer,
                                             no_answer_probability_threshold)
    f1_threshold = apply_no_ans_threshold(f1, no_answer_probs,
                                          qa_id_to_has_answer,
                                          no_answer_probability_threshold)

    evaluation = make_eval_dict(exact_threshold, f1_threshold)

    if has_answer_qids:
        has_ans_eval = make_eval_dict(exact_threshold,
                                      f1_threshold,
                                      qid_list=has_answer_qids)
        merge_eval(evaluation, has_ans_eval, 'HasAns')

    if no_answer_qids:
        no_ans_eval = make_eval_dict(exact_threshold,
                                     f1_threshold,
                                     qid_list=no_answer_qids)
        merge_eval(evaluation, no_ans_eval, 'NoAns')

    if no_answer_probs:
        find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs,
                             qa_id_to_has_answer)

    return evaluation


def get_final_text(pred_text, orig_text):
    '''Project the tokenized prediction back to the original text.'''
    # When we created the data, we kept track of the alignment between
    # original (whitespace tokenized) tokens and our WordPiece tokenized
    # tokens. So now `orig_text` contains the span of our original text
    # corresponding to the span that we predicted.
    #
    # However, `orig_text` may contain extra characters that we don't want
    # in our prediction.
    #
    # For example, let's say:
    #   pred_text = steve smith
    #   orig_text = Steve Smith's
    #
    # We don't want to return `orig_text` because it contains the extra
    # "'s".
    #
    # We don't want to return `pred_text` because it's already been
    # normalized (the SQuAD eval script also does punctuation
    # stripping/lower casing, but our tokenizer does additional
    # normalization like stripping accent characters).
    #
    # What we really want to return is "Steve Smith".
    #
    # Therefore, we have to apply a semi-complicated alignment heuristic
    # between `pred_text` and `orig_text` to get a character-to-character
    # alignment. This can fail in certain cases, in which case we just
    # return `orig_text`.
    from transformers import BasicTokenizer

    def _strip_spaces(text):
        ns_chars = []
        ns_to_s_map = collections.OrderedDict()
        for i, c in enumerate(text):
            if c == ' ':
                continue
            ns_to_s_map[len(ns_chars)] = i
            ns_chars.append(c)
        ns_text = ''.join(ns_chars)
        return ns_text, ns_to_s_map

    # We first tokenize `orig_text`, strip whitespace from the result and
    # `pred_text`, and check if they are the same length. If they are NOT
    # the same length, the heuristic has failed. If they are the same
    # length, we assume the characters are one-to-one aligned.
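    # A worked example of the alignment (assuming BasicTokenizer's default
    # lowercasing, which also splits the apostrophe into its own token):
    #   orig_text = "Steve Smith's"  -> tok_text = "steve smith ' s"
    #   pred_text = "steve smith" is found at tok_text positions 0..10.
    # After stripping spaces, both strings have the same length (12
    # characters), so each character position in `tok_text` can be mapped
    # through the no-space strings back to a position in `orig_text`,
    # recovering "Steve Smith".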
    tokenizer = BasicTokenizer()

    tok_text = ' '.join(tokenizer.tokenize(orig_text))

    start_position = tok_text.find(pred_text)
    if start_position == -1:
        return orig_text
    end_position = start_position + len(pred_text) - 1

    orig_ns_text, orig_ns_to_s_map = _strip_spaces(orig_text)
    tok_ns_text, tok_ns_to_s_map = _strip_spaces(tok_text)

    if len(orig_ns_text) != len(tok_ns_text):
        return orig_text

    # We then project the characters in `pred_text` back to `orig_text`
    # using the character-to-character alignment.
    tok_s_to_ns_map = {}
    for i, tok_index in tok_ns_to_s_map.items():
        tok_s_to_ns_map[tok_index] = i

    orig_start_position = None
    if start_position in tok_s_to_ns_map:
        ns_start_position = tok_s_to_ns_map[start_position]
        if ns_start_position in orig_ns_to_s_map:
            orig_start_position = orig_ns_to_s_map[ns_start_position]

    if orig_start_position is None:
        return orig_text

    orig_end_position = None
    if end_position in tok_s_to_ns_map:
        ns_end_position = tok_s_to_ns_map[end_position]
        if ns_end_position in orig_ns_to_s_map:
            orig_end_position = orig_ns_to_s_map[ns_end_position]

    if orig_end_position is None:
        return orig_text

    output_text = orig_text[orig_start_position:(orig_end_position + 1)]
    return output_text


def get_topk_indices(logits, n_best_size):
    index_and_score = sorted(enumerate(logits),
                             key=lambda x: x[1],
                             reverse=True)
    return [index for index, _ in index_and_score[:n_best_size]]


def _compute_softmax(scores):
    '''Compute softmax probability over raw logits.'''
    if not scores:
        return []

    # Subtract the max score before exponentiating for numerical stability.
    max_score = max(scores)
    exp_scores = [math.exp(score - max_score) for score in scores]
    total_sum = sum(exp_scores)
    return [score / total_sum for score in exp_scores]


def create_squad_answer_texts(examples, encoded_inputs, results, n_best_size,
                              max_answer_len, null_score_diff_threshold):
    _PrelimPrediction = collections.namedtuple('PrelimPrediction', [
        'feature_index', 'start_index', 'end_index', 'start_logit',
        'end_logit'
    ])
    _NbestPrediction = collections.namedtuple(
        'NbestPrediction', ['text', 'start_logit', 'end_logit'])

    example_index_to_features = collections.defaultdict(list)
    for feature in encoded_inputs:
        example_index_to_features[feature.example_index].append(feature)

    unique_id_to_result = {}
    for result in results:
        unique_id_to_result[result.unique_id] = result

    predicted_answer_texts = collections.OrderedDict()
    for example_index, example in enumerate(examples):
        features = example_index_to_features[example_index]
        prelim_predictions = []
        # Keep track of the minimum score of null start+end of position 0
        score_null = 1000000  # large and positive
        min_null_feature_index = 0  # the paragraph slice with min null score
        null_start_logit = 0  # the start logit at the slice with min null score
        null_end_logit = 0  # the end logit at the slice with min null score
        for feature_index, feature in enumerate(features):
            result = unique_id_to_result[feature.unique_id]
            start_indexes = get_topk_indices(result.start_logits, n_best_size)
            end_indexes = get_topk_indices(result.end_logits, n_best_size)
            # if we could have irrelevant answers, get the min score of
            # irrelevant
            feature_null_score = \
                result.start_logits[0] + result.end_logits[0]
            if feature_null_score < score_null:
                score_null = feature_null_score
                min_null_feature_index = feature_index
                null_start_logit = result.start_logits[0]
                null_end_logit = result.end_logits[0]
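            # Every (start, end) pair drawn from the top-k start logits and
            # top-k end logits below is a candidate span; e.g. with
            # n_best_size = 20, up to 400 pairs are considered per feature.
            # The filters keep only pairs whose indices map back into the
            # original context and whose span length does not exceed
            # `max_answer_len`.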
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # We could hypothetically create invalid predictions,
                    # e.g., predict that the start of the span is in the
                    # question. We throw out all invalid predictions.
                    if start_index >= len(feature.tokens):
                        continue
                    if end_index >= len(feature.tokens):
                        continue
                    if start_index not in feature.context_subtok_to_tok_idx:
                        continue
                    if end_index not in feature.context_subtok_to_tok_idx:
                        continue
                    if not feature.is_max_context_token.get(
                            start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_len:
                        continue
                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index]))

        prelim_predictions.append(
            _PrelimPrediction(feature_index=min_null_feature_index,
                              start_index=0,
                              end_index=0,
                              start_logit=null_start_logit,
                              end_logit=null_end_logit))
        prelim_predictions = sorted(prelim_predictions,
                                    key=lambda x:
                                    (x.start_logit + x.end_logit),
                                    reverse=True)

        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]
            if pred.start_index > 0:  # this is a non-null prediction
                tok_tokens = \
                    feature.tokens[pred.start_index:(pred.end_index + 1)]
                orig_doc_start = \
                    feature.context_subtok_to_tok_idx[pred.start_index]
                orig_doc_end = \
                    feature.context_subtok_to_tok_idx[pred.end_index]
                orig_tokens = \
                    example.context_tokens[orig_doc_start:(orig_doc_end + 1)]

                tok_text = ' '.join(tok_tokens)

                # De-tokenize WordPieces that have been split off.
                tok_text = tok_text.replace(' ##', '')
                tok_text = tok_text.replace('##', '')

                # Clean whitespace
                tok_text = tok_text.strip()
                tok_text = ' '.join(tok_text.split())
                orig_text = ' '.join(orig_tokens)

                final_text = get_final_text(tok_text, orig_text)
                if final_text in seen_predictions:
                    continue
                seen_predictions[final_text] = True
            else:
                final_text = ''
                seen_predictions[final_text] = True

            nbest.append(
                _NbestPrediction(text=final_text,
                                 start_logit=pred.start_logit,
                                 end_logit=pred.end_logit))

        # If we didn't include the empty option in the n-best, include it
        if '' not in seen_predictions:
            nbest.append(
                _NbestPrediction(text='',
                                 start_logit=null_start_logit,
                                 end_logit=null_end_logit))

        # In very rare edge cases we could have only a single null
        # prediction. So we just create a nonce prediction in this case to
        # avoid failure.
        if len(nbest) == 1:
            nbest.insert(
                0,
                _NbestPrediction(text='empty', start_logit=0.0,
                                 end_logit=0.0))
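        # What remains below: a final guard for an empty n-best list, then
        # the example-level decision that compares the best null score
        # against the best non-null span. If their difference exceeds
        # `null_score_diff_threshold`, the empty string is predicted
        # (SQuAD v2-style unanswerable handling).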
        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(
                _NbestPrediction(text='empty', start_logit=0.0,
                                 end_logit=0.0))

        total_scores = []
        best_non_null_entry = None
        for entry in nbest:
            total_scores.append(entry.start_logit + entry.end_logit)
            if not best_non_null_entry:
                if entry.text:
                    best_non_null_entry = entry

        score_diff = \
            score_null - best_non_null_entry.start_logit - \
            best_non_null_entry.end_logit
        if score_diff > null_score_diff_threshold:
            predicted_answer_texts[example.qa_id] = ''
        else:
            predicted_answer_texts[example.qa_id] = \
                best_non_null_entry.text

    return predicted_answer_texts


def compute_squad_metrics(examples,
                          encoded_inputs,
                          results,
                          n_best_size,
                          max_answer_len,
                          null_score_diff_threshold=0.0,
                          return_text=False):
    # Note: the default threshold is 0.0 rather than None, since a None
    # value would make the `score_diff > null_score_diff_threshold`
    # comparison in create_squad_answer_texts raise a TypeError.
    predicted_answer_texts = create_squad_answer_texts(
        examples, encoded_inputs, results, n_best_size, max_answer_len,
        null_score_diff_threshold)
    raw_metrics = squad_evaluate(examples, predicted_answer_texts)
    metrics = {
        k: v
        for k, v in raw_metrics.items()
        if k in ('exact', 'f1', 'exact_and_f1')
    }
    if return_text:
        return predicted_answer_texts
    return metrics


def load_squad_metrics(ctx, **kwargs):
    examples = ctx.get('{}_examples'.format(ctx.cur_split))
    encoded_inputs = ctx.get('{}_encoded'.format(ctx.cur_split))
    results = ctx.squad_results
    n_best_size = ctx.cfg.model.n_best_size
    max_answer_len = ctx.cfg.model.max_answer_len
    null_score_diff_threshold = ctx.cfg.model.null_score_diff_threshold
    metrics = compute_squad_metrics(examples, encoded_inputs, results,
                                    n_best_size, max_answer_len,
                                    null_score_diff_threshold)
    return metrics


def call_squad_metric(types):
    if 'squad' in types:
        the_larger_the_better = True
        return 'squad', load_squad_metrics, the_larger_the_better


register_metric('squad', call_squad_metric)
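

# A minimal, self-contained smoke test of the pure-Python scoring path.
# The data below is hypothetical; `SimpleNamespace` stands in for the
# example objects this module expects, which only need `qa_id` and
# `val_answer` here.
if __name__ == '__main__':
    from types import SimpleNamespace

    examples = [
        SimpleNamespace(qa_id='q1', val_answer=[{'text': 'Steve Smith'}]),
        SimpleNamespace(qa_id='q2', val_answer=[]),  # unanswerable
    ]
    preds = {'q1': 'steve smith', 'q2': ''}
    evaluation = squad_evaluate(examples, preds)
    # Both predictions match after normalization, so exact == f1 == 100.0
    print(evaluation['exact'], evaluation['f1'])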