From e3cb7a0b60890cdc0b62c7f3f7ffe87cb6b3f799 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 21 Jun 2021 16:37:28 +0100 Subject: [PATCH] Tensorflow QA example (#12252) * New Tensorflow QA example! * Style pass * Updating README.md for the new example * flake8 fixes * Update examples/tensorflow/question-answering/README.md Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- .../tensorflow/question-answering/README.md | 55 +- .../tensorflow/question-answering/run_qa.py | 694 ++++++++++++++++++ .../question-answering/run_tf_squad.py | 255 ------- .../tensorflow/question-answering/utils_qa.py | 425 +++++++++++ 4 files changed, 1157 insertions(+), 272 deletions(-) create mode 100755 examples/tensorflow/question-answering/run_qa.py delete mode 100755 examples/tensorflow/question-answering/run_tf_squad.py create mode 100644 examples/tensorflow/question-answering/utils_qa.py diff --git a/examples/tensorflow/question-answering/README.md b/examples/tensorflow/question-answering/README.md index 00c2d5f809..b7c0443b1b 100644 --- a/examples/tensorflow/question-answering/README.md +++ b/examples/tensorflow/question-answering/README.md @@ -1,5 +1,5 @@ -## SQuAD with the Tensorflow Trainer +# Question answering example -```bash -python run_tf_squad.py \ - --model_name_or_path bert-base-uncased \ - --output_dir model \ - --max_seq_length 384 \ - --num_train_epochs 2 \ - --per_gpu_train_batch_size 8 \ - --per_gpu_eval_batch_size 16 \ - --do_train \ - --logging_dir logs \ - --logging_steps 10 \ - --learning_rate 3e-5 \ - --doc_stride 128 +This folder contains the `run_qa.py` script, demonstrating *question answering* with the 🤗 Transformers library. +For straightforward use-cases you may be able to use this script without modification, although we have also +included comments in the code to indicate areas that you may need to adapt to your own projects. + +### Usage notes +Note that when contexts are long they may be split into multiple training cases, not all of which may contain +the answer span. + +As-is, the example script will train on SQuAD or any other question-answering dataset formatted the same way, and can handle user +inputs as well. + +### Multi-GPU and TPU usage + +By default, the script uses a `MirroredStrategy` and will use multiple GPUs effectively if they are available. TPUs +can also be used by passing the name of the TPU resource with the `--tpu` argument. There are some issues surrounding +these strategies and our models right now, which are most likely to appear in the evaluation/prediction steps. We're +actively working on better support for multi-GPU and TPU training in TF, but if you encounter problems a quick +workaround is to train in the multi-GPU or TPU context and then perform predictions outside of it. + +### Memory usage and data loading + +One thing to note is that all data is loaded into memory in this script. Most question answering datasets are small +enough that this is not an issue, but if you have a very large dataset you will need to modify the script to handle +data streaming. This is particularly challenging for TPUs, given the stricter requirements and the sheer volume of data +required to keep them fed. A full explanation of all the possible pitfalls is a bit beyond this example script and +README, but for more information you can see the 'Input Datasets' section of +[this document](https://www.tensorflow.org/guide/tpu). + +### Example command +``` +python run_qa.py \ +--model_name_or_path distilbert-base-cased \ +--output_dir output \ +--dataset_name squad \ +--do_train \ +--do_eval \ ``` - -For the moment evaluation is not available in the Tensorflow Trainer only the training. diff --git a/examples/tensorflow/question-answering/run_qa.py b/examples/tensorflow/question-answering/run_qa.py new file mode 100755 index 0000000000..fe6bf58658 --- /dev/null +++ b/examples/tensorflow/question-answering/run_qa.py @@ -0,0 +1,694 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2020 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for question answering. +""" +# You can also adapt this script on your own question answering task. Pointers for this are left as comments. + +import logging +import os +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + +import tensorflow as tf +from datasets import load_dataset, load_metric + +import transformers +from transformers import ( + AutoConfig, + AutoTokenizer, + EvalPrediction, + HfArgumentParser, + PreTrainedTokenizerFast, + TFAutoModelForQuestionAnswering, + TFTrainingArguments, + set_seed, +) +from transformers.file_utils import CONFIG_NAME, TF2_WEIGHTS_NAME +from transformers.utils import check_min_version +from utils_qa import postprocess_qa_predictions + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.7.0.dev0") + +logger = logging.getLogger(__name__) + + +# region Arguments +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to directory to store the pretrained models downloaded from huggingface.co"}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) + validation_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, + ) + test_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input test data file to evaluate the perplexity on (a text file)."}, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_seq_length: int = field( + default=384, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to `max_seq_length`. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can " + "be faster on GPU but will be slower on TPU)." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + }, + ) + version_2_with_negative: bool = field( + default=False, metadata={"help": "If true, some of the examples do not have an answer."} + ) + null_score_diff_threshold: float = field( + default=0.0, + metadata={ + "help": "The threshold used to select the null answer: if the best answer has a score that is less than " + "the score of the null answer minus this threshold, the null answer is selected for this example. " + "Only useful when `version_2_with_negative=True`." + }, + ) + doc_stride: int = field( + default=128, + metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."}, + ) + n_best_size: int = field( + default=20, + metadata={"help": "The total number of n-best predictions to generate when looking for an answer."}, + ) + max_answer_length: int = field( + default=30, + metadata={ + "help": "The maximum length of an answer that can be generated. This is needed because the start " + "and end predictions are not conditioned on one another." + }, + ) + + def __post_init__(self): + if ( + self.dataset_name is None + and self.train_file is None + and self.validation_file is None + and self.test_file is None + ): + raise ValueError("Need either a dataset name or a training/validation file/test_file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + if self.test_file is not None: + extension = self.test_file.split(".")[-1] + assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." + + +# endregion + +# region Helper classes +class SavePretrainedCallback(tf.keras.callbacks.Callback): + # Hugging Face models have a save_pretrained() method that saves both the weights and the necessary + # metadata to allow them to be loaded as a pretrained model in future. This is a simple Keras callback + # that saves the model with this method after each epoch. + def __init__(self, output_dir, **kwargs): + super().__init__() + self.output_dir = output_dir + + def on_epoch_end(self, epoch, logs=None): + self.model.save_pretrained(self.output_dir) + + +def convert_dataset_for_tensorflow( + dataset, batch_size, dataset_mode="variable_batch", shuffle=True, drop_remainder=True +): + """Converts a Hugging Face dataset to a Tensorflow Dataset. The dataset_mode controls whether we pad all batches + to the maximum sequence length, or whether we only pad to the maximum length within that batch. The former + is most useful when training on TPU, as a new graph compilation is required for each sequence length. + """ + + def densify_ragged_batch(features, label=None): + features = { + feature: ragged_tensor.to_tensor(shape=batch_shape[feature]) if feature in tensor_keys else ragged_tensor + for feature, ragged_tensor in features.items() + } + if label is None: + return features + else: + return features, label + + tensor_keys = ["attention_mask", "input_ids"] + label_keys = ["start_positions", "end_positions"] + if dataset_mode == "variable_batch": + batch_shape = {key: None for key in tensor_keys} + data = {key: tf.ragged.constant(dataset[key]) for key in tensor_keys} + elif dataset_mode == "constant_batch": + data = {key: tf.ragged.constant(dataset[key]) for key in tensor_keys} + batch_shape = { + key: tf.concat(([batch_size], ragged_tensor.bounding_shape()[1:]), axis=0) + for key, ragged_tensor in data.items() + } + else: + raise ValueError("Unknown dataset mode!") + + if all([key in dataset.features for key in label_keys]): + for key in label_keys: + data[key] = tf.convert_to_tensor(dataset[key]) + dummy_labels = tf.zeros_like(dataset[key]) + tf_dataset = tf.data.Dataset.from_tensor_slices((data, dummy_labels)) + else: + tf_dataset = tf.data.Dataset.from_tensor_slices(data) + if shuffle: + tf_dataset = tf_dataset.shuffle(buffer_size=len(dataset)) + tf_dataset = tf_dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder).map(densify_ragged_batch) + return tf_dataset + + +# endregion + + +def main(): + # region Argument parsing + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + output_dir = Path(training_args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + # endregion + + # region Checkpoints + checkpoint = None + if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir: + if (output_dir / CONFIG_NAME).is_file() and (output_dir / TF2_WEIGHTS_NAME).is_file(): + checkpoint = output_dir + logger.info( + f"Checkpoint detected, resuming training from checkpoint in {training_args.output_dir}. To avoid this" + " behavior, change the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + else: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to continue regardless." + ) + # endregion + + # region Logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + logger.setLevel(logging.INFO if training_args.should_log else logging.WARN) + + # Set the verbosity to info of the Transformers logger (on main process only): + if training_args.should_log: + transformers.utils.logging.set_verbosity_info() + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + logger.info(f"Training/evaluation parameters {training_args}") + # endregion + + # Set seed before initializing model. + set_seed(training_args.seed) + + # region Load Data + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir) + else: + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + extension = data_args.train_file.split(".")[-1] + + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = data_args.validation_file.split(".")[-1] + if data_args.test_file is not None: + data_files["test"] = data_args.test_file + extension = data_args.test_file.split(".")[-1] + datasets = load_dataset(extension, data_files=data_files, field="data", cache_dir=model_args.cache_dir) + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + # endregion + + # region Load pretrained model and tokenizer + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=True, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + # endregion + + # region Tokenizer check: this script requires a fast tokenizer. + if not isinstance(tokenizer, PreTrainedTokenizerFast): + raise ValueError( + "This example script only works for models that have a fast tokenizer. Checkout the big table of models " + "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this " + "requirement" + ) + # endregion + + # region Preprocessing the datasets + # Preprocessing is slightly different for training and evaluation. + if training_args.do_train: + column_names = datasets["train"].column_names + elif training_args.do_eval: + column_names = datasets["validation"].column_names + else: + column_names = datasets["test"].column_names + question_column_name = "question" if "question" in column_names else column_names[0] + context_column_name = "context" if "context" in column_names else column_names[1] + answer_column_name = "answers" if "answers" in column_names else column_names[2] + + # Padding side determines if we do (question|context) or (context|question). + pad_on_right = tokenizer.padding_side == "right" + + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + + # Training preprocessing + def prepare_train_features(examples): + # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results + # in one example possible giving several features when a context is long, each of those features having a + # context that overlaps a bit the context of the previous feature. + tokenized_examples = tokenizer( + examples[question_column_name if pad_on_right else context_column_name], + examples[context_column_name if pad_on_right else question_column_name], + truncation="only_second" if pad_on_right else "only_first", + max_length=max_seq_length, + stride=data_args.doc_stride, + return_overflowing_tokens=True, + return_offsets_mapping=True, + padding="max_length" if data_args.pad_to_max_length else False, + ) + + # Since one example might give us several features if it has a long context, we need a map from a feature to + # its corresponding example. This key gives us just that. + sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") + # The offset mappings will give us a map from token to character position in the original context. This will + # help us compute the start_positions and end_positions. + offset_mapping = tokenized_examples.pop("offset_mapping") + + # Let's label those examples! + tokenized_examples["start_positions"] = [] + tokenized_examples["end_positions"] = [] + + for i, offsets in enumerate(offset_mapping): + # We will label impossible answers with the index of the CLS token. + input_ids = tokenized_examples["input_ids"][i] + cls_index = input_ids.index(tokenizer.cls_token_id) + + # Grab the sequence corresponding to that example (to know what is the context and what is the question). + sequence_ids = tokenized_examples.sequence_ids(i) + + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = sample_mapping[i] + answers = examples[answer_column_name][sample_index] + # If no answers are given, set the cls_index as answer. + if len(answers["answer_start"]) == 0: + tokenized_examples["start_positions"].append(cls_index) + tokenized_examples["end_positions"].append(cls_index) + else: + # Start/end character index of the answer in the text. + start_char = answers["answer_start"][0] + end_char = start_char + len(answers["text"][0]) + + # Start token index of the current span in the text. + token_start_index = 0 + while sequence_ids[token_start_index] != (1 if pad_on_right else 0): + token_start_index += 1 + + # End token index of the current span in the text. + token_end_index = len(input_ids) - 1 + while sequence_ids[token_end_index] != (1 if pad_on_right else 0): + token_end_index -= 1 + + # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index). + if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char): + tokenized_examples["start_positions"].append(cls_index) + tokenized_examples["end_positions"].append(cls_index) + else: + # Otherwise move the token_start_index and token_end_index to the two ends of the answer. + # Note: we could go after the last offset if the answer is the last word (edge case). + while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char: + token_start_index += 1 + tokenized_examples["start_positions"].append(token_start_index - 1) + while offsets[token_end_index][1] >= end_char: + token_end_index -= 1 + tokenized_examples["end_positions"].append(token_end_index + 1) + + return tokenized_examples + + processed_datasets = dict() + if training_args.do_train: + if "train" not in datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = datasets["train"] + if data_args.max_train_samples is not None: + # We will select sample from whole data if agument is specified + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + # Create train feature from dataset + train_dataset = train_dataset.map( + prepare_train_features, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + if data_args.max_train_samples is not None: + # Number of samples might increase during Feature Creation, We select only specified max samples + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + processed_datasets["train"] = train_dataset + + # Validation preprocessing + def prepare_validation_features(examples): + # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results + # in one example possible giving several features when a context is long, each of those features having a + # context that overlaps a bit the context of the previous feature. + tokenized_examples = tokenizer( + examples[question_column_name if pad_on_right else context_column_name], + examples[context_column_name if pad_on_right else question_column_name], + truncation="only_second" if pad_on_right else "only_first", + max_length=max_seq_length, + stride=data_args.doc_stride, + return_overflowing_tokens=True, + return_offsets_mapping=True, + padding="max_length" if data_args.pad_to_max_length else False, + ) + + # Since one example might give us several features if it has a long context, we need a map from a feature to + # its corresponding example. This key gives us just that. + sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") + + # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the + # corresponding example_id and we will store the offset mappings. + tokenized_examples["example_id"] = [] + + for i in range(len(tokenized_examples["input_ids"])): + # Grab the sequence corresponding to that example (to know what is the context and what is the question). + sequence_ids = tokenized_examples.sequence_ids(i) + context_index = 1 if pad_on_right else 0 + + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = sample_mapping[i] + tokenized_examples["example_id"].append(examples["id"][sample_index]) + + # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token + # position is part of the context or not. + tokenized_examples["offset_mapping"][i] = [ + (o if sequence_ids[k] == context_index else None) + for k, o in enumerate(tokenized_examples["offset_mapping"][i]) + ] + + return tokenized_examples + + if training_args.do_eval: + if "validation" not in datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_examples = datasets["validation"] + if data_args.max_eval_samples is not None: + # We will select sample from whole data + eval_examples = eval_examples.select(range(data_args.max_eval_samples)) + # Validation Feature Creation + eval_dataset = eval_examples.map( + prepare_validation_features, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + if data_args.max_eval_samples is not None: + # During Feature creation dataset samples might increase, we will select required samples again + eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) + processed_datasets["validation"] = eval_dataset + + if training_args.do_predict: + if "test" not in datasets: + raise ValueError("--do_predict requires a test dataset") + predict_examples = datasets["test"] + if data_args.max_predict_samples is not None: + # We will select sample from whole data + predict_examples = predict_examples.select(range(data_args.max_predict_samples)) + # Predict Feature Creation + predict_dataset = predict_examples.map( + prepare_validation_features, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + if data_args.max_predict_samples is not None: + # During Feature creation dataset samples might increase, we will select required samples again + predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) + processed_datasets["test"] = predict_dataset + # endregion + + # region Metrics and Post-processing: + def post_processing_function(examples, features, predictions, stage="eval"): + # Post-processing: we match the start logits and end logits to answers in the original context. + predictions = postprocess_qa_predictions( + examples=examples, + features=features, + predictions=predictions, + version_2_with_negative=data_args.version_2_with_negative, + n_best_size=data_args.n_best_size, + max_answer_length=data_args.max_answer_length, + null_score_diff_threshold=data_args.null_score_diff_threshold, + output_dir=training_args.output_dir, + prefix=stage, + ) + # Format the result to the format the metric expects. + if data_args.version_2_with_negative: + formatted_predictions = [ + {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items() + ] + else: + formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()] + + references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples] + return EvalPrediction(predictions=formatted_predictions, label_ids=references) + + metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad") + + def compute_metrics(p: EvalPrediction): + return metric.compute(predictions=p.predictions, references=p.label_ids) + + # endregion + + with training_args.strategy.scope(): + # region Load model + if checkpoint is None: + model_path = model_args.model_name_or_path + else: + model_path = checkpoint + model = TFAutoModelForQuestionAnswering.from_pretrained( + model_path, + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + optimizer = tf.keras.optimizers.Adam( + learning_rate=training_args.learning_rate, + beta_1=training_args.adam_beta1, + beta_2=training_args.adam_beta2, + epsilon=training_args.adam_epsilon, + clipnorm=training_args.max_grad_norm, + ) + + def dummy_loss(y_true, y_pred): + return tf.reduce_mean(y_pred) + + losses = {"loss": dummy_loss} + model.compile(optimizer=optimizer, loss=losses) + # endregion + + # region Training + if training_args.do_train: + # Make a tf.data.Dataset for this + if isinstance(training_args.strategy, tf.distribute.TPUStrategy) or data_args.pad_to_max_length: + logger.info("Padding all batches to max length because argument was set or we're on TPU.") + dataset_mode = "constant_batch" + else: + dataset_mode = "variable_batch" + training_dataset = convert_dataset_for_tensorflow( + processed_datasets["train"], + batch_size=training_args.per_device_train_batch_size, + dataset_mode=dataset_mode, + drop_remainder=True, + shuffle=True, + ) + model.fit(training_dataset, epochs=int(training_args.num_train_epochs)) + # endregion + + # region Evaluation + if training_args.do_eval: + logger.info("*** Evaluation ***") + eval_inputs = { + "input_ids": tf.ragged.constant(processed_datasets["validation"]["input_ids"]).to_tensor(), + "attention_mask": tf.ragged.constant(processed_datasets["validation"]["attention_mask"]).to_tensor(), + } + eval_predictions = model.predict(eval_inputs) + + post_processed_eval = post_processing_function( + datasets["validation"], + processed_datasets["validation"], + (eval_predictions.start_logits, eval_predictions.end_logits), + ) + metrics = compute_metrics(post_processed_eval) + logging.info("Evaluation metrics:") + for metric, value in metrics.items(): + logging.info(f"{metric}: {value:.3f}") + # endregion + + # region Prediction + if training_args.do_predict: + logger.info("*** Predict ***") + predict_inputs = { + "input_ids": tf.ragged.constant(processed_datasets["test"]["input_ids"]).to_tensor(), + "attention_mask": tf.ragged.constant(processed_datasets["test"]["attention_mask"]).to_tensor(), + } + test_predictions = model.predict(predict_inputs) + post_processed_test = post_processing_function( + datasets["test"], + processed_datasets["test"], + (test_predictions.start_logits, test_predictions.end_logits), + ) + metrics = compute_metrics(post_processed_test) + + logging.info("Test metrics:") + for metric, value in metrics.items(): + logging.info(f"{metric}: {value:.3f}") + # endregion + + if training_args.push_to_hub: + model.push_to_hub() + + +if __name__ == "__main__": + main() diff --git a/examples/tensorflow/question-answering/run_tf_squad.py b/examples/tensorflow/question-answering/run_tf_squad.py deleted file mode 100755 index 20723f70e8..0000000000 --- a/examples/tensorflow/question-answering/run_tf_squad.py +++ /dev/null @@ -1,255 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Fine-tuning the library models for question-answering.""" - - -import logging -import os -from dataclasses import dataclass, field -from typing import Optional - -import tensorflow as tf - -from transformers import ( - AutoConfig, - AutoTokenizer, - HfArgumentParser, - TFAutoModelForQuestionAnswering, - TFTrainer, - TFTrainingArguments, - squad_convert_examples_to_features, -) -from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor -from transformers.utils import logging as hf_logging - - -hf_logging.set_verbosity_info() -hf_logging.enable_default_handler() -hf_logging.enable_explicit_format() - - -logger = logging.getLogger(__name__) - - -@dataclass -class ModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. - """ - - model_name_or_path: str = field( - metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} - ) - config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} - ) - tokenizer_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} - ) - use_fast: bool = field(default=False, metadata={"help": "Set this flag to use fast tokenization."}) - # If you want to tweak more attributes on your tokenizer, you should do it in a distinct script, - # or just modify its tokenizer_config.json. - cache_dir: Optional[str] = field( - default=None, - metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, - ) - - -@dataclass -class DataTrainingArguments: - """ - Arguments pertaining to what data we are going to input our model for training and eval. - """ - - data_dir: Optional[str] = field( - default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."} - ) - use_tfds: Optional[bool] = field(default=True, metadata={"help": "If TFDS should be used or not."}) - max_seq_length: int = field( - default=128, - metadata={ - "help": "The maximum total input sequence length after tokenization. Sequences longer " - "than this will be truncated, sequences shorter will be padded." - }, - ) - doc_stride: int = field( - default=128, - metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."}, - ) - max_query_length: int = field( - default=64, - metadata={ - "help": "The maximum number of tokens for the question. Questions longer than this will " - "be truncated to this length." - }, - ) - max_answer_length: int = field( - default=30, - metadata={ - "help": "The maximum length of an answer that can be generated. This is needed because the start " - "and end predictions are not conditioned on one another." - }, - ) - overwrite_cache: bool = field( - default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} - ) - version_2_with_negative: bool = field( - default=False, metadata={"help": "If true, the SQuAD examples contain some that do not have an answer."} - ) - null_score_diff_threshold: float = field( - default=0.0, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."} - ) - n_best_size: int = field( - default=20, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."} - ) - lang_id: int = field( - default=0, - metadata={ - "help": "language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)" - }, - ) - - -def main(): - # See all possible arguments in src/transformers/training_args.py - # or by passing the --help flag to this script. - # We now keep distinct sets of args, for a cleaner separation of concerns. - parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments)) - model_args, data_args, training_args = parser.parse_args_into_dataclasses() - - if ( - os.path.exists(training_args.output_dir) - and os.listdir(training_args.output_dir) - and training_args.do_train - and not training_args.overwrite_output_dir - ): - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." - ) - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - logger.info( - f"n_replicas: {training_args.n_replicas}, distributed training: {bool(training_args.n_replicas > 1)}, " - f"16-bits training: {training_args.fp16}" - ) - logger.info(f"Training/evaluation parameters {training_args}") - - # Prepare Question-Answering task - # Load pretrained model and tokenizer - # - # Distributed training: - # The .from_pretrained methods guarantee that only one local process can concurrently - # download model & vocab. - - config = AutoConfig.from_pretrained( - model_args.config_name if model_args.config_name else model_args.model_name_or_path, - cache_dir=model_args.cache_dir, - ) - tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, - cache_dir=model_args.cache_dir, - use_fast=model_args.use_fast, - ) - - with training_args.strategy.scope(): - model = TFAutoModelForQuestionAnswering.from_pretrained( - model_args.model_name_or_path, - from_pt=bool(".bin" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - ) - - # Get datasets - if data_args.use_tfds: - if data_args.version_2_with_negative: - logger.warning("tensorflow_datasets does not handle version 2 of SQuAD. Switch to version 1 automatically") - - try: - import tensorflow_datasets as tfds - except ImportError: - raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.") - - tfds_examples = tfds.load("squad", data_dir=data_args.data_dir) - train_examples = ( - SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=False) - if training_args.do_train - else None - ) - eval_examples = ( - SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=True) - if training_args.do_eval - else None - ) - else: - processor = SquadV2Processor() if data_args.version_2_with_negative else SquadV1Processor() - train_examples = processor.get_train_examples(data_args.data_dir) if training_args.do_train else None - eval_examples = processor.get_dev_examples(data_args.data_dir) if training_args.do_eval else None - - train_dataset = ( - squad_convert_examples_to_features( - examples=train_examples, - tokenizer=tokenizer, - max_seq_length=data_args.max_seq_length, - doc_stride=data_args.doc_stride, - max_query_length=data_args.max_query_length, - is_training=True, - return_dataset="tf", - ) - if training_args.do_train - else None - ) - - train_dataset = train_dataset.apply(tf.data.experimental.assert_cardinality(len(train_examples))) - - eval_dataset = ( - squad_convert_examples_to_features( - examples=eval_examples, - tokenizer=tokenizer, - max_seq_length=data_args.max_seq_length, - doc_stride=data_args.doc_stride, - max_query_length=data_args.max_query_length, - is_training=False, - return_dataset="tf", - ) - if training_args.do_eval - else None - ) - - eval_dataset = eval_dataset.apply(tf.data.experimental.assert_cardinality(len(eval_examples))) - - # Initialize our Trainer - trainer = TFTrainer( - model=model, - args=training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - ) - - # Training - if training_args.do_train: - trainer.train() - trainer.save_model() - tokenizer.save_pretrained(training_args.output_dir) - - -if __name__ == "__main__": - main() diff --git a/examples/tensorflow/question-answering/utils_qa.py b/examples/tensorflow/question-answering/utils_qa.py new file mode 100644 index 0000000000..36d911b9e9 --- /dev/null +++ b/examples/tensorflow/question-answering/utils_qa.py @@ -0,0 +1,425 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Post-processing utilities for question answering. +""" +import collections +import json +import logging +import os +from typing import Optional, Tuple + +import numpy as np +from tqdm.auto import tqdm + + +logger = logging.getLogger(__name__) + + +def postprocess_qa_predictions( + examples, + features, + predictions: Tuple[np.ndarray, np.ndarray], + version_2_with_negative: bool = False, + n_best_size: int = 20, + max_answer_length: int = 30, + null_score_diff_threshold: float = 0.0, + output_dir: Optional[str] = None, + prefix: Optional[str] = None, +): + """ + Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the + original contexts. This is the base postprocessing functions for models that only return start and end logits. + + Args: + examples: The non-preprocessed dataset (see the main script for more information). + features: The processed dataset (see the main script for more information). + predictions (:obj:`Tuple[np.ndarray, np.ndarray]`): + The predictions of the model: two arrays containing the start logits and the end logits respectively. Its + first dimension must match the number of elements of :obj:`features`. + version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the underlying dataset contains examples with no answers. + n_best_size (:obj:`int`, `optional`, defaults to 20): + The total number of n-best predictions to generate when looking for an answer. + max_answer_length (:obj:`int`, `optional`, defaults to 30): + The maximum length of an answer that can be generated. This is needed because the start and end predictions + are not conditioned on one another. + null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0): + The threshold used to select the null answer: if the best answer has a score that is less than the score of + the null answer minus this threshold, the null answer is selected for this example (note that the score of + the null answer for an example giving several features is the minimum of the scores for the null answer on + each feature: all features must be aligned on the fact they `want` to predict a null answer). + + Only useful when :obj:`version_2_with_negative` is :obj:`True`. + output_dir (:obj:`str`, `optional`): + If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if + :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null + answers, are saved in `output_dir`. + prefix (:obj:`str`, `optional`): + If provided, the dictionaries mentioned above are saved with `prefix` added to their names. + is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether this process is the main process or not (used to determine if logging/saves should be done). + """ + assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)." + all_start_logits, all_end_logits = predictions + + assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features." + + # Build a map example to its corresponding features. + example_id_to_index = {k: i for i, k in enumerate(examples["id"])} + features_per_example = collections.defaultdict(list) + for i, feature in enumerate(features): + features_per_example[example_id_to_index[feature["example_id"]]].append(i) + + # The dictionaries we have to fill. + all_predictions = collections.OrderedDict() + all_nbest_json = collections.OrderedDict() + if version_2_with_negative: + scores_diff_json = collections.OrderedDict() + + # Logging. + logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") + + # Let's loop over all the examples! + for example_index, example in enumerate(tqdm(examples)): + # Those are the indices of the features associated to the current example. + feature_indices = features_per_example[example_index] + + min_null_prediction = None + prelim_predictions = [] + + # Looping through all the features associated to the current example. + for feature_index in feature_indices: + # We grab the predictions of the model for this feature. + start_logits = all_start_logits[feature_index] + end_logits = all_end_logits[feature_index] + # This is what will allow us to map some the positions in our logits to span of texts in the original + # context. + offset_mapping = features[feature_index]["offset_mapping"] + # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context + # available in the current feature. + token_is_max_context = features[feature_index].get("token_is_max_context", None) + + # Update minimum null prediction. + feature_null_score = start_logits[0] + end_logits[0] + if min_null_prediction is None or min_null_prediction["score"] > feature_null_score: + min_null_prediction = { + "offsets": (0, 0), + "score": feature_null_score, + "start_logit": start_logits[0], + "end_logit": end_logits[0], + } + + # Go through all possibilities for the `n_best_size` greater start and end logits. + start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist() + end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist() + for start_index in start_indexes: + for end_index in end_indexes: + # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond + # to part of the input_ids that are not in the context. + if ( + start_index >= len(offset_mapping) + or end_index >= len(offset_mapping) + or offset_mapping[start_index] is None + or offset_mapping[end_index] is None + ): + continue + # Don't consider answers with a length that is either < 0 or > max_answer_length. + if end_index < start_index or end_index - start_index + 1 > max_answer_length: + continue + # Don't consider answer that don't have the maximum context available (if such information is + # provided). + if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): + continue + prelim_predictions.append( + { + "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), + "score": start_logits[start_index] + end_logits[end_index], + "start_logit": start_logits[start_index], + "end_logit": end_logits[end_index], + } + ) + if version_2_with_negative: + # Add the minimum null prediction + prelim_predictions.append(min_null_prediction) + null_score = min_null_prediction["score"] + + # Only keep the best `n_best_size` predictions. + predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size] + + # Add back the minimum null prediction if it was removed because of its low score. + if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions): + predictions.append(min_null_prediction) + + # Use the offsets to gather the answer text in the original context. + context = example["context"] + for pred in predictions: + offsets = pred.pop("offsets") + pred["text"] = context[offsets[0] : offsets[1]] + + # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid + # failure. + if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""): + predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}) + + # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using + # the LogSumExp trick). + scores = np.array([pred.pop("score") for pred in predictions]) + exp_scores = np.exp(scores - np.max(scores)) + probs = exp_scores / exp_scores.sum() + + # Include the probabilities in our predictions. + for prob, pred in zip(probs, predictions): + pred["probability"] = prob + + # Pick the best prediction. If the null answer is not possible, this is easy. + if not version_2_with_negative: + all_predictions[example["id"]] = predictions[0]["text"] + else: + # Otherwise we first need to find the best non-empty prediction. + i = 0 + while predictions[i]["text"] == "": + i += 1 + best_non_null_pred = predictions[i] + + # Then we compare to the null prediction using the threshold. + score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"] + scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable. + if score_diff > null_score_diff_threshold: + all_predictions[example["id"]] = "" + else: + all_predictions[example["id"]] = best_non_null_pred["text"] + + # Make `predictions` JSON-serializable by casting np.float back to float. + all_nbest_json[example["id"]] = [ + {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()} + for pred in predictions + ] + + # If we have an output_dir, let's save all those dicts. + if output_dir is not None: + assert os.path.isdir(output_dir), f"{output_dir} is not a directory." + + prediction_file = os.path.join( + output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" + ) + nbest_file = os.path.join( + output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json" + ) + if version_2_with_negative: + null_odds_file = os.path.join( + output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json" + ) + + logger.info(f"Saving predictions to {prediction_file}.") + with open(prediction_file, "w") as writer: + writer.write(json.dumps(all_predictions, indent=4) + "\n") + logger.info(f"Saving nbest_preds to {nbest_file}.") + with open(nbest_file, "w") as writer: + writer.write(json.dumps(all_nbest_json, indent=4) + "\n") + if version_2_with_negative: + logger.info(f"Saving null_odds to {null_odds_file}.") + with open(null_odds_file, "w") as writer: + writer.write(json.dumps(scores_diff_json, indent=4) + "\n") + + return all_predictions + + +def postprocess_qa_predictions_with_beam_search( + examples, + features, + predictions: Tuple[np.ndarray, np.ndarray], + version_2_with_negative: bool = False, + n_best_size: int = 20, + max_answer_length: int = 30, + start_n_top: int = 5, + end_n_top: int = 5, + output_dir: Optional[str] = None, + prefix: Optional[str] = None, + is_world_process_zero: bool = True, +): + """ + Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the + original contexts. This is the postprocessing functions for models that return start and end logits, indices, as well as + cls token predictions. + + Args: + examples: The non-preprocessed dataset (see the main script for more information). + features: The processed dataset (see the main script for more information). + predictions (:obj:`Tuple[np.ndarray, np.ndarray]`): + The predictions of the model: two arrays containing the start logits and the end logits respectively. Its + first dimension must match the number of elements of :obj:`features`. + version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the underlying dataset contains examples with no answers. + n_best_size (:obj:`int`, `optional`, defaults to 20): + The total number of n-best predictions to generate when looking for an answer. + max_answer_length (:obj:`int`, `optional`, defaults to 30): + The maximum length of an answer that can be generated. This is needed because the start and end predictions + are not conditioned on one another. + start_n_top (:obj:`int`, `optional`, defaults to 5): + The number of top start logits too keep when searching for the :obj:`n_best_size` predictions. + end_n_top (:obj:`int`, `optional`, defaults to 5): + The number of top end logits too keep when searching for the :obj:`n_best_size` predictions. + output_dir (:obj:`str`, `optional`): + If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if + :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null + answers, are saved in `output_dir`. + prefix (:obj:`str`, `optional`): + If provided, the dictionaries mentioned above are saved with `prefix` added to their names. + is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether this process is the main process or not (used to determine if logging/saves should be done). + """ + assert len(predictions) == 5, "`predictions` should be a tuple with five elements." + start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions + + assert len(predictions[0]) == len( + features + ), f"Got {len(predictions[0])} predicitions and {len(features)} features." + + # Build a map example to its corresponding features. + example_id_to_index = {k: i for i, k in enumerate(examples["id"])} + features_per_example = collections.defaultdict(list) + for i, feature in enumerate(features): + features_per_example[example_id_to_index[feature["example_id"]]].append(i) + + # The dictionaries we have to fill. + all_predictions = collections.OrderedDict() + all_nbest_json = collections.OrderedDict() + scores_diff_json = collections.OrderedDict() if version_2_with_negative else None + + # Logging. + logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN) + logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") + + # Let's loop over all the examples! + for example_index, example in enumerate(tqdm(examples)): + # Those are the indices of the features associated to the current example. + feature_indices = features_per_example[example_index] + + min_null_score = None + prelim_predictions = [] + + # Looping through all the features associated to the current example. + for feature_index in feature_indices: + # We grab the predictions of the model for this feature. + start_log_prob = start_top_log_probs[feature_index] + start_indexes = start_top_index[feature_index] + end_log_prob = end_top_log_probs[feature_index] + end_indexes = end_top_index[feature_index] + feature_null_score = cls_logits[feature_index] + # This is what will allow us to map some the positions in our logits to span of texts in the original + # context. + offset_mapping = features[feature_index]["offset_mapping"] + # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context + # available in the current feature. + token_is_max_context = features[feature_index].get("token_is_max_context", None) + + # Update minimum null prediction + if min_null_score is None or feature_null_score < min_null_score: + min_null_score = feature_null_score + + # Go through all possibilities for the `n_start_top`/`n_end_top` greater start and end logits. + for i in range(start_n_top): + for j in range(end_n_top): + start_index = int(start_indexes[i]) + j_index = i * end_n_top + j + end_index = int(end_indexes[j_index]) + # Don't consider out-of-scope answers (last part of the test should be unnecessary because of the + # p_mask but let's not take any risk) + if ( + start_index >= len(offset_mapping) + or end_index >= len(offset_mapping) + or offset_mapping[start_index] is None + or offset_mapping[end_index] is None + ): + continue + # Don't consider answers with a length negative or > max_answer_length. + if end_index < start_index or end_index - start_index + 1 > max_answer_length: + continue + # Don't consider answer that don't have the maximum context available (if such information is + # provided). + if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): + continue + prelim_predictions.append( + { + "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), + "score": start_log_prob[i] + end_log_prob[j_index], + "start_log_prob": start_log_prob[i], + "end_log_prob": end_log_prob[j_index], + } + ) + + # Only keep the best `n_best_size` predictions. + predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size] + + # Use the offsets to gather the answer text in the original context. + context = example["context"] + for pred in predictions: + offsets = pred.pop("offsets") + pred["text"] = context[offsets[0] : offsets[1]] + + # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid + # failure. + if len(predictions) == 0: + predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6}) + + # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using + # the LogSumExp trick). + scores = np.array([pred.pop("score") for pred in predictions]) + exp_scores = np.exp(scores - np.max(scores)) + probs = exp_scores / exp_scores.sum() + + # Include the probabilities in our predictions. + for prob, pred in zip(probs, predictions): + pred["probability"] = prob + + # Pick the best prediction and set the probability for the null answer. + all_predictions[example["id"]] = predictions[0]["text"] + if version_2_with_negative: + scores_diff_json[example["id"]] = float(min_null_score) + + # Make `predictions` JSON-serializable by casting np.float back to float. + all_nbest_json[example["id"]] = [ + {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()} + for pred in predictions + ] + + # If we have an output_dir, let's save all those dicts. + if output_dir is not None: + assert os.path.isdir(output_dir), f"{output_dir} is not a directory." + + prediction_file = os.path.join( + output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" + ) + nbest_file = os.path.join( + output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json" + ) + if version_2_with_negative: + null_odds_file = os.path.join( + output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json" + ) + + print(f"Saving predictions to {prediction_file}.") + with open(prediction_file, "w") as writer: + writer.write(json.dumps(all_predictions, indent=4) + "\n") + print(f"Saving nbest_preds to {nbest_file}.") + with open(nbest_file, "w") as writer: + writer.write(json.dumps(all_nbest_json, indent=4) + "\n") + if version_2_with_negative: + print(f"Saving null_odds to {null_odds_file}.") + with open(null_odds_file, "w") as writer: + writer.write(json.dumps(scores_diff_json, indent=4) + "\n") + + return all_predictions, scores_diff_json