From 582f516adb0a2cc175db2a097cbdd4d3cbada9db Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 20 Jan 2021 04:52:13 -0500 Subject: [PATCH] Use datasets squad_v2 metric in run_qa (#9677) --- examples/question-answering/requirements.txt | 2 +- examples/question-answering/run_qa.py | 4 +- .../question-answering/run_qa_beam_search.py | 4 +- .../squad_v2_local/evaluate.py | 322 ------------------ .../squad_v2_local/squad_v2_local.py | 128 ------- 5 files changed, 3 insertions(+), 457 deletions(-) delete mode 100644 examples/question-answering/squad_v2_local/evaluate.py delete mode 100644 examples/question-answering/squad_v2_local/squad_v2_local.py diff --git a/examples/question-answering/requirements.txt b/examples/question-answering/requirements.txt index ff72fc8415..c8205f0d3d 100644 --- a/examples/question-answering/requirements.txt +++ b/examples/question-answering/requirements.txt @@ -1 +1 @@ -datasets >= 1.1.3 +datasets >= 1.2.1 diff --git a/examples/question-answering/run_qa.py b/examples/question-answering/run_qa.py index 4b83806e6c..dc3cce05b9 100644 --- a/examples/question-answering/run_qa.py +++ b/examples/question-answering/run_qa.py @@ -433,9 +433,7 @@ def main(): references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in datasets["validation"]] return EvalPrediction(predictions=formatted_predictions, label_ids=references) - # TODO: Once the fix lands in a Datasets release, remove the _local here and the squad_v2_local folder. - current_dir = os.path.sep.join(os.path.join(__file__).split(os.path.sep)[:-1]) - metric = load_metric(os.path.join(current_dir, "squad_v2_local") if data_args.version_2_with_negative else "squad") + metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad") def compute_metrics(p: EvalPrediction): return metric.compute(predictions=p.predictions, references=p.label_ids) diff --git a/examples/question-answering/run_qa_beam_search.py b/examples/question-answering/run_qa_beam_search.py index 0a2846bc68..6d343ce766 100644 --- a/examples/question-answering/run_qa_beam_search.py +++ b/examples/question-answering/run_qa_beam_search.py @@ -472,9 +472,7 @@ def main(): references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in datasets["validation"]] return EvalPrediction(predictions=formatted_predictions, label_ids=references) - # TODO: Once the fix lands in a Datasets release, remove the _local here and the squad_v2_local folder. - current_dir = os.path.sep.join(os.path.join(__file__).split(os.path.sep)[:-1]) - metric = load_metric(os.path.join(current_dir, "squad_v2_local") if data_args.version_2_with_negative else "squad") + metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad") def compute_metrics(p: EvalPrediction): return metric.compute(predictions=p.predictions, references=p.label_ids) diff --git a/examples/question-answering/squad_v2_local/evaluate.py b/examples/question-answering/squad_v2_local/evaluate.py deleted file mode 100644 index 549688a735..0000000000 --- a/examples/question-answering/squad_v2_local/evaluate.py +++ /dev/null @@ -1,322 +0,0 @@ -"""Official evaluation script for SQuAD version 2.0. - -In addition to basic functionality, we also compute additional statistics and -plot precision-recall curves if an additional na_prob.json file is provided. -This file is expected to map question ID's to the model's predicted probability -that a question is unanswerable. -""" -import argparse -import collections -import json -import os -import re -import string -import sys - -import numpy as np - - -OPTS = None - - -def parse_args(): - parser = argparse.ArgumentParser("Official evaluation script for SQuAD version 2.0.") - parser.add_argument("data_file", metavar="data.json", help="Input data JSON file.") - parser.add_argument("pred_file", metavar="pred.json", help="Model predictions.") - parser.add_argument( - "--out-file", "-o", metavar="eval.json", help="Write accuracy metrics to file (default is stdout)." - ) - parser.add_argument( - "--na-prob-file", "-n", metavar="na_prob.json", help="Model estimates of probability of no answer." - ) - parser.add_argument( - "--na-prob-thresh", - "-t", - type=float, - default=1.0, - help='Predict "" if no-answer probability exceeds this (default = 1.0).', - ) - parser.add_argument( - "--out-image-dir", "-p", metavar="out_images", default=None, help="Save precision-recall curves to directory." - ) - parser.add_argument("--verbose", "-v", action="store_true") - if len(sys.argv) == 1: - parser.print_help() - sys.exit(1) - return parser.parse_args() - - -def make_qid_to_has_ans(dataset): - qid_to_has_ans = {} - for article in dataset: - for p in article["paragraphs"]: - for qa in p["qas"]: - qid_to_has_ans[qa["id"]] = bool(qa["answers"]["text"]) - return qid_to_has_ans - - -def normalize_answer(s): - """Lower text and remove punctuation, articles and extra whitespace.""" - - def remove_articles(text): - regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) - return re.sub(regex, " ", text) - - def white_space_fix(text): - return " ".join(text.split()) - - def remove_punc(text): - exclude = set(string.punctuation) - return "".join(ch for ch in text if ch not in exclude) - - def lower(text): - return text.lower() - - return white_space_fix(remove_articles(remove_punc(lower(s)))) - - -def get_tokens(s): - if not s: - return [] - return normalize_answer(s).split() - - -def compute_exact(a_gold, a_pred): - return int(normalize_answer(a_gold) == normalize_answer(a_pred)) - - -def compute_f1(a_gold, a_pred): - gold_toks = get_tokens(a_gold) - pred_toks = get_tokens(a_pred) - common = collections.Counter(gold_toks) & collections.Counter(pred_toks) - num_same = sum(common.values()) - if len(gold_toks) == 0 or len(pred_toks) == 0: - # If either is no-answer, then F1 is 1 if they agree, 0 otherwise - return int(gold_toks == pred_toks) - if num_same == 0: - return 0 - precision = 1.0 * num_same / len(pred_toks) - recall = 1.0 * num_same / len(gold_toks) - f1 = (2 * precision * recall) / (precision + recall) - return f1 - - -def get_raw_scores(dataset, preds): - exact_scores = {} - f1_scores = {} - for article in dataset: - for p in article["paragraphs"]: - for qa in p["qas"]: - qid = qa["id"] - gold_answers = [t for t in qa["answers"]["text"] if normalize_answer(t)] - if not gold_answers: - # For unanswerable questions, only correct answer is empty string - gold_answers = [""] - if qid not in preds: - print("Missing prediction for %s" % qid) - continue - a_pred = preds[qid] - # Take max over all gold answers - exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers) - f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers) - return exact_scores, f1_scores - - -def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): - new_scores = {} - for qid, s in scores.items(): - pred_na = na_probs[qid] > na_prob_thresh - if pred_na: - new_scores[qid] = float(not qid_to_has_ans[qid]) - else: - new_scores[qid] = s - return new_scores - - -def make_eval_dict(exact_scores, f1_scores, qid_list=None): - if not qid_list: - total = len(exact_scores) - return collections.OrderedDict( - [ - ("exact", 100.0 * sum(exact_scores.values()) / total), - ("f1", 100.0 * sum(f1_scores.values()) / total), - ("total", total), - ] - ) - else: - total = len(qid_list) - return collections.OrderedDict( - [ - ("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total), - ("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total), - ("total", total), - ] - ) - - -def merge_eval(main_eval, new_eval, prefix): - for k in new_eval: - main_eval["%s_%s" % (prefix, k)] = new_eval[k] - - -def plot_pr_curve(precisions, recalls, out_image, title): - plt.step(recalls, precisions, color="b", alpha=0.2, where="post") - plt.fill_between(recalls, precisions, step="post", alpha=0.2, color="b") - plt.xlabel("Recall") - plt.ylabel("Precision") - plt.xlim([0.0, 1.05]) - plt.ylim([0.0, 1.05]) - plt.title(title) - plt.savefig(out_image) - plt.clf() - - -def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, out_image=None, title=None): - qid_list = sorted(na_probs, key=lambda k: na_probs[k]) - true_pos = 0.0 - cur_p = 1.0 - cur_r = 0.0 - precisions = [1.0] - recalls = [0.0] - avg_prec = 0.0 - for i, qid in enumerate(qid_list): - if qid_to_has_ans[qid]: - true_pos += scores[qid] - cur_p = true_pos / float(i + 1) - cur_r = true_pos / float(num_true_pos) - if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 1]]: - # i.e., if we can put a threshold after this point - avg_prec += cur_p * (cur_r - recalls[-1]) - precisions.append(cur_p) - recalls.append(cur_r) - if out_image: - plot_pr_curve(precisions, recalls, out_image, title) - return {"ap": 100.0 * avg_prec} - - -def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, out_image_dir): - if out_image_dir and not os.path.exists(out_image_dir): - os.makedirs(out_image_dir) - num_true_pos = sum(1 for v in qid_to_has_ans.values() if v) - if num_true_pos == 0: - return - pr_exact = make_precision_recall_eval( - exact_raw, - na_probs, - num_true_pos, - qid_to_has_ans, - out_image=os.path.join(out_image_dir, "pr_exact.png"), - title="Precision-Recall curve for Exact Match score", - ) - pr_f1 = make_precision_recall_eval( - f1_raw, - na_probs, - num_true_pos, - qid_to_has_ans, - out_image=os.path.join(out_image_dir, "pr_f1.png"), - title="Precision-Recall curve for F1 score", - ) - oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()} - pr_oracle = make_precision_recall_eval( - oracle_scores, - na_probs, - num_true_pos, - qid_to_has_ans, - out_image=os.path.join(out_image_dir, "pr_oracle.png"), - title="Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)", - ) - merge_eval(main_eval, pr_exact, "pr_exact") - merge_eval(main_eval, pr_f1, "pr_f1") - merge_eval(main_eval, pr_oracle, "pr_oracle") - - -def histogram_na_prob(na_probs, qid_list, image_dir, name): - if not qid_list: - return - x = [na_probs[k] for k in qid_list] - weights = np.ones_like(x) / float(len(x)) - plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0)) - plt.xlabel("Model probability of no-answer") - plt.ylabel("Proportion of dataset") - plt.title("Histogram of no-answer probability: %s" % name) - plt.savefig(os.path.join(image_dir, "na_prob_hist_%s.png" % name)) - plt.clf() - - -def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): - num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) - cur_score = num_no_ans - best_score = cur_score - best_thresh = 0.0 - qid_list = sorted(na_probs, key=lambda k: na_probs[k]) - for i, qid in enumerate(qid_list): - if qid not in scores: - continue - if qid_to_has_ans[qid]: - diff = scores[qid] - else: - if preds[qid]: - diff = -1 - else: - diff = 0 - cur_score += diff - if cur_score > best_score: - best_score = cur_score - best_thresh = na_probs[qid] - return 100.0 * best_score / len(scores), best_thresh - - -def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): - best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) - best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) - main_eval["best_exact"] = best_exact - main_eval["best_exact_thresh"] = exact_thresh - main_eval["best_f1"] = best_f1 - main_eval["best_f1_thresh"] = f1_thresh - - -def main(): - with open(OPTS.data_file) as f: - dataset_json = json.load(f) - dataset = dataset_json["data"] - with open(OPTS.pred_file) as f: - preds = json.load(f) - if OPTS.na_prob_file: - with open(OPTS.na_prob_file) as f: - na_probs = json.load(f) - else: - na_probs = {k: 0.0 for k in preds} - qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False - has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] - no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] - exact_raw, f1_raw = get_raw_scores(dataset, preds) - exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh) - f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh) - out_eval = make_eval_dict(exact_thresh, f1_thresh) - if has_ans_qids: - has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) - merge_eval(out_eval, has_ans_eval, "HasAns") - if no_ans_qids: - no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) - merge_eval(out_eval, no_ans_eval, "NoAns") - if OPTS.na_prob_file: - find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans) - if OPTS.na_prob_file and OPTS.out_image_dir: - run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, OPTS.out_image_dir) - histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, "hasAns") - histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, "noAns") - if OPTS.out_file: - with open(OPTS.out_file, "w") as f: - json.dump(out_eval, f) - else: - print(json.dumps(out_eval, indent=2)) - - -if __name__ == "__main__": - OPTS = parse_args() - if OPTS.out_image_dir: - import matplotlib - - matplotlib.use("Agg") - import matplotlib.pyplot as plt - main() diff --git a/examples/question-answering/squad_v2_local/squad_v2_local.py b/examples/question-answering/squad_v2_local/squad_v2_local.py deleted file mode 100644 index d1b7b45623..0000000000 --- a/examples/question-answering/squad_v2_local/squad_v2_local.py +++ /dev/null @@ -1,128 +0,0 @@ -# coding=utf-8 -# Copyright 2020 The HuggingFace Datasets Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" SQuAD v2 metric. """ - -import datasets - -from .evaluate import ( - apply_no_ans_threshold, - find_all_best_thresh, - get_raw_scores, - make_eval_dict, - make_qid_to_has_ans, - merge_eval, -) - - -_CITATION = """\ -@inproceedings{Rajpurkar2016SQuAD10, - title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text}, - author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang}, - booktitle={EMNLP}, - year={2016} -} -""" - -_DESCRIPTION = """ -This metric wrap the official scoring script for version 2 of the Stanford Question -Answering Dataset (SQuAD). - -Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by -crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, -from the corresponding reading passage, or the question might be unanswerable. - -SQuAD2.0 combines the 100,000 questions in SQuAD1.1 with over 50,000 unanswerable questions -written adversarially by crowdworkers to look similar to answerable ones. -To do well on SQuAD2.0, systems must not only answer questions when possible, but also -determine when no answer is supported by the paragraph and abstain from answering. -""" - -_KWARGS_DESCRIPTION = """ -Computes SQuAD v2 scores (F1 and EM). -Args: - predictions: List of triple for question-answers to score with the following elements: - - the question-answer 'id' field as given in the references (see below) - - the text of the answer - - the probability that the question has no answer - references: List of question-answers dictionaries with the following key-values: - - 'id': id of the question-answer pair (see above), - - 'answers': a list of Dict {'text': text of the answer as a string} - no_answer_threshold: float - Probability threshold to decide that a question has no answer. -Returns: - 'exact': Exact match (the normalized answer exactly match the gold answer) - 'f1': The F-score of predicted tokens versus the gold answer - 'total': Number of score considered - 'HasAns_exact': Exact match (the normalized answer exactly match the gold answer) - 'HasAns_f1': The F-score of predicted tokens versus the gold answer - 'HasAns_total': Number of score considered - 'NoAns_exact': Exact match (the normalized answer exactly match the gold answer) - 'NoAns_f1': The F-score of predicted tokens versus the gold answer - 'NoAns_total': Number of score considered - 'best_exact': Best exact match (with varying threshold) - 'best_exact_thresh': No-answer probability threshold associated to the best exact match - 'best_f1': Best F1 (with varying threshold) - 'best_f1_thresh': No-answer probability threshold associated to the best F1 -""" - - -class SquadV2(datasets.Metric): - def _info(self): - return datasets.MetricInfo( - description=_DESCRIPTION, - citation=_CITATION, - inputs_description=_KWARGS_DESCRIPTION, - features=datasets.Features( - { - "predictions": { - "id": datasets.Value("string"), - "prediction_text": datasets.Value("string"), - "no_answer_probability": datasets.Value("float32"), - }, - "references": { - "id": datasets.Value("string"), - "answers": datasets.features.Sequence( - {"text": datasets.Value("string"), "answer_start": datasets.Value("int32")} - ), - }, - } - ), - codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"], - reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"], - ) - - def _compute(self, predictions, references, no_answer_threshold=1.0): - no_answer_probabilities = dict((p["id"], p["no_answer_probability"]) for p in predictions) - dataset = [{"paragraphs": [{"qas": references}]}] - predictions = dict((p["id"], p["prediction_text"]) for p in predictions) - - qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False - has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] - no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] - - exact_raw, f1_raw = get_raw_scores(dataset, predictions) - exact_thresh = apply_no_ans_threshold(exact_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold) - f1_thresh = apply_no_ans_threshold(f1_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold) - out_eval = make_eval_dict(exact_thresh, f1_thresh) - - if has_ans_qids: - has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) - merge_eval(out_eval, has_ans_eval, "HasAns") - if no_ans_qids: - no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) - merge_eval(out_eval, no_ans_eval, "NoAns") - find_all_best_thresh(out_eval, predictions, exact_raw, f1_raw, no_answer_probabilities, qid_to_has_ans) - - return out_eval