From 78807d86eb8791bdbe149adff62af11f778a9267 Mon Sep 17 00:00:00 2001 From: Kamal Raj Date: Tue, 21 Sep 2021 18:34:48 +0530 Subject: [PATCH] [FLAX] Question Answering Example (#13649) * flax qa example * Updated README: Added Large model * added utils_qa.py FULL_COPIES * Updates: 1. Copyright Year updated 2. added dtype arg 3. passing seed and dtype to load model 4. Check eval flag before running eval * updated README * updated code comment --- examples/flax/question-answering/README.md | 128 +++ .../flax/question-answering/requirements.txt | 5 + examples/flax/question-answering/run_qa.py | 905 ++++++++++++++++++ examples/flax/question-answering/utils_qa.py | 427 +++++++++ utils/check_copies.py | 5 +- 5 files changed, 1469 insertions(+), 1 deletion(-) create mode 100644 examples/flax/question-answering/README.md create mode 100644 examples/flax/question-answering/requirements.txt create mode 100644 examples/flax/question-answering/run_qa.py create mode 100644 examples/flax/question-answering/utils_qa.py diff --git a/examples/flax/question-answering/README.md b/examples/flax/question-answering/README.md new file mode 100644 index 0000000000..6b8360349e --- /dev/null +++ b/examples/flax/question-answering/README.md @@ -0,0 +1,128 @@ + + +# Question Answering examples + +Based on the script [`run_qa.py`](https://github.com/huggingface/transformers/blob/master/examples/flax/question-answering/run_qa.py). + +**Note:** This script only works with models that have a fast tokenizer (backed by the 🤗 Tokenizers library) as it +uses special features of those tokenizers. You can check if your favorite model has a fast tokenizer in +[this table](https://huggingface.co/transformers/index.html#supported-frameworks), if it doesn't you can still use the old version +of the script. + + +The following example fine-tunes BERT on SQuAD: + +To begin with it is recommended to create a model repository to save the trained model and logs. +Here we call the model `"bert-qa-squad-test"`, but you can change the model name as you like. + +You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that +you are logged in) or via the command line: + +``` +huggingface-cli repo create bert-qa-squad-test +``` + +Next we clone the model repository to add the tokenizer and model files. + +``` +git clone https://huggingface.co//bert-qa-squad-test +``` + +Great, we have set up our model repository. During training, we will automatically +push the training logs and model weights to the repo. + +Next, let's add a symbolic link to the `run_qa.py`. + +```bash +export MODEL_DIR="./bert-qa-squad-test" +ln -s ~/transformers/examples/flax/question-answering/run_qa.py run_qa.py +``` + +```bash +python run_qa.py \ + --model_name_or_path bert-base-uncased \ + --dataset_name squad \ + --do_train \ + --do_eval \ + --max_seq_length 384 \ + --doc_stride 128 \ + --learning_rate 3e-5 \ + --num_train_epochs 2 \ + --per_device_train_batch_size 12 \ + --output_dir ${MODEL_DIR} \ + --eval_steps 1000 \ + --push_to_hub +``` + +Using the command above, the script will train for 2 epochs and run eval after each epoch. +Metrics and hyperparameters are stored in Tensorflow event files in `--output_dir`. +You can see the results by running `tensorboard` in that directory: + +```bash +$ tensorboard --logdir . +``` + +or directly on the hub under *Training metrics*. + +Training with the previously defined hyper-parameters yields the following results: + +```bash +f1 = 88.62 +exact_match = 81.34 +``` + +sample Metrics - [tfhub.dev](https://tensorboard.dev/experiment/6gU75Hx8TGCnc6tr4ZgI9Q) + +Here is an example training on 4 TITAN RTX GPUs and Bert Whole Word Masking uncased model to reach a F1 > 93 on SQuAD1.1: + +```bash +export CUDA_VISIBLE_DEVICES=0,1,2,3 +python run_qa.py \ +--model_name_or_path bert-large-uncased-whole-word-masking \ +--dataset_name squad \ +--do_train \ +--do_eval \ +--per_device_train_batch_size 6 \ +--learning_rate 3e-5 \ +--num_train_epochs 2 \ +--max_seq_length 384 \ +--doc_stride 128 \ +--output_dir /tmp/wwm_uncased_finetuned_squad/ \ +--eval_steps 1000 +``` + +Training with the previously defined hyper-parameters yields the following results: + +```bash +f1 = 93.31 +exact_match = 87.04 +``` + + +### Usage notes + +Note that when contexts are long they may be split into multiple training cases, not all of which may contain +the answer span. + +As-is, the example script will train on SQuAD or any other question-answering dataset formatted the same way, and can handle user +inputs as well. + +### Memory usage and data loading + +One thing to note is that all data is loaded into memory in this script. Most question answering datasets are small +enough that this is not an issue, but if you have a very large dataset you will need to modify the script to handle +data streaming. diff --git a/examples/flax/question-answering/requirements.txt b/examples/flax/question-answering/requirements.txt new file mode 100644 index 0000000000..12790f285d --- /dev/null +++ b/examples/flax/question-answering/requirements.txt @@ -0,0 +1,5 @@ +datasets >= 1.8.0 +jax>=0.2.17 +jaxlib>=0.1.68 +flax>=0.3.4 +optax>=0.0.8 \ No newline at end of file diff --git a/examples/flax/question-answering/run_qa.py b/examples/flax/question-answering/run_qa.py new file mode 100644 index 0000000000..f45bcfc21e --- /dev/null +++ b/examples/flax/question-answering/run_qa.py @@ -0,0 +1,905 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for question answering. +""" +# You can also adapt this script on your own question answering task. Pointers for this are left as comments. + +import logging +import os +import random +import sys +import time +from dataclasses import dataclass, field +from itertools import chain +from typing import Any, Callable, Dict, Optional, Tuple + +import datasets +import numpy as np +from datasets import load_dataset, load_metric +from tqdm import tqdm + +import jax +import jax.numpy as jnp +import optax +import transformers +from flax import struct, traverse_util +from flax.jax_utils import replicate, unreplicate +from flax.metrics import tensorboard +from flax.training import train_state +from flax.training.common_utils import get_metrics, onehot, shard +from transformers import ( + AutoConfig, + AutoTokenizer, + EvalPrediction, + FlaxAutoModelForQuestionAnswering, + HfArgumentParser, + PreTrainedTokenizerFast, + TrainingArguments, +) +from transformers.utils import check_min_version +from utils_qa import postprocess_qa_predictions + + +logger = logging.getLogger(__name__) + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.11.0.dev0") + +Array = Any +Dataset = datasets.arrow_dataset.Dataset +PRNGKey = Any + + +# region Arguments +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to directory to store the pretrained models downloaded from huggingface.co"}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + dtype: Optional[str] = field( + default="float32", + metadata={ + "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`." + }, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) + validation_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, + ) + test_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input test data file to evaluate the perplexity on (a text file)."}, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_seq_length: int = field( + default=384, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to `max_seq_length`. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can " + "be faster on GPU but will be slower on TPU)." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + }, + ) + version_2_with_negative: bool = field( + default=False, metadata={"help": "If true, some of the examples do not have an answer."} + ) + null_score_diff_threshold: float = field( + default=0.0, + metadata={ + "help": "The threshold used to select the null answer: if the best answer has a score that is less than " + "the score of the null answer minus this threshold, the null answer is selected for this example. " + "Only useful when `version_2_with_negative=True`." + }, + ) + doc_stride: int = field( + default=128, + metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."}, + ) + n_best_size: int = field( + default=20, + metadata={"help": "The total number of n-best predictions to generate when looking for an answer."}, + ) + max_answer_length: int = field( + default=30, + metadata={ + "help": "The maximum length of an answer that can be generated. This is needed because the start " + "and end predictions are not conditioned on one another." + }, + ) + + def __post_init__(self): + if ( + self.dataset_name is None + and self.train_file is None + and self.validation_file is None + and self.test_file is None + ): + raise ValueError("Need either a dataset name or a training/validation file/test_file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + if self.test_file is not None: + extension = self.test_file.split(".")[-1] + assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." + + +# endregion + +# region Create a train state +def create_train_state( + model: FlaxAutoModelForQuestionAnswering, + learning_rate_fn: Callable[[int], float], + num_labels: int, + training_args: TrainingArguments, +) -> train_state.TrainState: + """Create initial training state.""" + + class TrainState(train_state.TrainState): + """Train state with an Optax optimizer. + + The two functions below differ depending on whether the task is classification + or regression. + + Args: + logits_fn: Applied to last layer to obtain the logits. + loss_fn: Function to compute the loss. + """ + + logits_fn: Callable = struct.field(pytree_node=False) + loss_fn: Callable = struct.field(pytree_node=False) + + # We use Optax's "masking" functionality to not apply weight decay + # to bias and LayerNorm scale parameters. decay_mask_fn returns a + # mask boolean with the same structure as the parameters. + # The mask is True for parameters that should be decayed. + # Note that this mask is specifically adapted for FlaxBERT-like models. + # For other models, one should correct the layer norm parameter naming + # accordingly. + def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + + tx = optax.adamw( + learning_rate=learning_rate_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=training_args.weight_decay, + mask=decay_mask_fn, + ) + + def cross_entropy_loss(logits, labels): + start_loss = optax.softmax_cross_entropy(logits[0], onehot(labels[0], num_classes=num_labels)) + end_loss = optax.softmax_cross_entropy(logits[1], onehot(labels[1], num_classes=num_labels)) + xentropy = (start_loss + end_loss) / 2.0 + return jnp.mean(xentropy) + + return TrainState.create( + apply_fn=model.__call__, + params=model.params, + tx=tx, + logits_fn=lambda logits: logits, + loss_fn=cross_entropy_loss, + ) + + +# endregion + + +# region Create learning rate function +def create_learning_rate_fn( + train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.array]: + """Returns a linear warmup, linear_decay learning rate function.""" + steps_per_epoch = train_ds_size // train_batch_size + num_train_steps = steps_per_epoch * num_train_epochs + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +# endregion + +# region train data iterator +def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int): + """Returns shuffled batches of size `batch_size` from truncated `train dataset`, sharded over all local devices.""" + steps_per_epoch = len(dataset) // batch_size + perms = jax.random.permutation(rng, len(dataset)) + perms = perms[: steps_per_epoch * batch_size] # Skip incomplete batch. + perms = perms.reshape((steps_per_epoch, batch_size)) + + for perm in perms: + batch = dataset[perm] + batch = {k: np.array(v) for k, v in batch.items()} + batch = shard(batch) + + yield batch + + +# endregion + +# region eval data iterator +def eval_data_collator(dataset: Dataset, batch_size: int): + """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices.""" + for i in range(len(dataset) // batch_size): + batch = dataset[i * batch_size : (i + 1) * batch_size] + batch = {k: np.array(v) for k, v in batch.items()} + batch = shard(batch) + + yield batch + + +# endregion + + +def main(): + # region Argument parsing + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + # endregion + + # region Logging + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + # Setup logging, we only want one process per machine to log things on the screen. + logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) + if jax.process_index() == 0: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + # endregion + + # region Load Data + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset( + data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir + ) + else: + # Loading the dataset from local csv or json file. + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + extension = data_args.train_file.split(".")[-1] + + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = data_args.validation_file.split(".")[-1] + if data_args.test_file is not None: + data_files["test"] = data_args.test_file + extension = data_args.test_file.split(".")[-1] + raw_datasets = load_dataset(extension, data_files=data_files, field="data", cache_dir=model_args.cache_dir) + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + # endregion + + # region Load pretrained model and tokenizer + # + # Load pretrained model and tokenizer + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=True, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + # endregion + + # region Tokenizer check: this script requires a fast tokenizer. + if not isinstance(tokenizer, PreTrainedTokenizerFast): + raise ValueError( + "This example script only works for models that have a fast tokenizer. Checkout the big table of models " + "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this " + "requirement" + ) + # endregion + + # region Preprocessing the datasets + # Preprocessing is slightly different for training and evaluation. + if training_args.do_train: + column_names = raw_datasets["train"].column_names + elif training_args.do_eval: + column_names = raw_datasets["validation"].column_names + else: + column_names = raw_datasets["test"].column_names + question_column_name = "question" if "question" in column_names else column_names[0] + context_column_name = "context" if "context" in column_names else column_names[1] + answer_column_name = "answers" if "answers" in column_names else column_names[2] + + # Padding side determines if we do (question|context) or (context|question). + pad_on_right = tokenizer.padding_side == "right" + + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + + # Training preprocessing + def prepare_train_features(examples): + # Some of the questions have lots of whitespace on the left, which is not useful and will make the + # truncation of the context fail (the tokenized question will take a lots of space). So we remove that + # left whitespace + examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]] + + # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results + # in one example possible giving several features when a context is long, each of those features having a + # context that overlaps a bit the context of the previous feature. + tokenized_examples = tokenizer( + examples[question_column_name if pad_on_right else context_column_name], + examples[context_column_name if pad_on_right else question_column_name], + truncation="only_second" if pad_on_right else "only_first", + max_length=max_seq_length, + stride=data_args.doc_stride, + return_overflowing_tokens=True, + return_offsets_mapping=True, + padding="max_length", + ) + + # Since one example might give us several features if it has a long context, we need a map from a feature to + # its corresponding example. This key gives us just that. + sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") + # The offset mappings will give us a map from token to character position in the original context. This will + # help us compute the start_positions and end_positions. + offset_mapping = tokenized_examples.pop("offset_mapping") + + # Let's label those examples! + tokenized_examples["start_positions"] = [] + tokenized_examples["end_positions"] = [] + + for i, offsets in enumerate(offset_mapping): + # We will label impossible answers with the index of the CLS token. + input_ids = tokenized_examples["input_ids"][i] + cls_index = input_ids.index(tokenizer.cls_token_id) + + # Grab the sequence corresponding to that example (to know what is the context and what is the question). + sequence_ids = tokenized_examples.sequence_ids(i) + + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = sample_mapping[i] + answers = examples[answer_column_name][sample_index] + # If no answers are given, set the cls_index as answer. + if len(answers["answer_start"]) == 0: + tokenized_examples["start_positions"].append(cls_index) + tokenized_examples["end_positions"].append(cls_index) + else: + # Start/end character index of the answer in the text. + start_char = answers["answer_start"][0] + end_char = start_char + len(answers["text"][0]) + + # Start token index of the current span in the text. + token_start_index = 0 + while sequence_ids[token_start_index] != (1 if pad_on_right else 0): + token_start_index += 1 + + # End token index of the current span in the text. + token_end_index = len(input_ids) - 1 + while sequence_ids[token_end_index] != (1 if pad_on_right else 0): + token_end_index -= 1 + + # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index). + if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char): + tokenized_examples["start_positions"].append(cls_index) + tokenized_examples["end_positions"].append(cls_index) + else: + # Otherwise move the token_start_index and token_end_index to the two ends of the answer. + # Note: we could go after the last offset if the answer is the last word (edge case). + while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char: + token_start_index += 1 + tokenized_examples["start_positions"].append(token_start_index - 1) + while offsets[token_end_index][1] >= end_char: + token_end_index -= 1 + tokenized_examples["end_positions"].append(token_end_index + 1) + + return tokenized_examples + + processed_raw_datasets = dict() + if training_args.do_train: + if "train" not in raw_datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = raw_datasets["train"] + if data_args.max_train_samples is not None: + # We will select sample from whole data if agument is specified + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + # Create train feature from dataset + train_dataset = train_dataset.map( + prepare_train_features, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + if data_args.max_train_samples is not None: + # Number of samples might increase during Feature Creation, We select only specified max samples + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + processed_raw_datasets["train"] = train_dataset + + # Validation preprocessing + def prepare_validation_features(examples): + # Some of the questions have lots of whitespace on the left, which is not useful and will make the + # truncation of the context fail (the tokenized question will take a lots of space). So we remove that + # left whitespace + examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]] + + # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results + # in one example possible giving several features when a context is long, each of those features having a + # context that overlaps a bit the context of the previous feature. + tokenized_examples = tokenizer( + examples[question_column_name if pad_on_right else context_column_name], + examples[context_column_name if pad_on_right else question_column_name], + truncation="only_second" if pad_on_right else "only_first", + max_length=max_seq_length, + stride=data_args.doc_stride, + return_overflowing_tokens=True, + return_offsets_mapping=True, + padding="max_length", + ) + + # Since one example might give us several features if it has a long context, we need a map from a feature to + # its corresponding example. This key gives us just that. + sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") + + # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the + # corresponding example_id and we will store the offset mappings. + tokenized_examples["example_id"] = [] + + for i in range(len(tokenized_examples["input_ids"])): + # Grab the sequence corresponding to that example (to know what is the context and what is the question). + sequence_ids = tokenized_examples.sequence_ids(i) + context_index = 1 if pad_on_right else 0 + + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = sample_mapping[i] + tokenized_examples["example_id"].append(examples["id"][sample_index]) + + # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token + # position is part of the context or not. + tokenized_examples["offset_mapping"][i] = [ + (o if sequence_ids[k] == context_index else None) + for k, o in enumerate(tokenized_examples["offset_mapping"][i]) + ] + + return tokenized_examples + + if training_args.do_eval: + if "validation" not in raw_datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_examples = raw_datasets["validation"] + if data_args.max_eval_samples is not None: + # We will select sample from whole data + eval_examples = eval_examples.select(range(data_args.max_eval_samples)) + # Validation Feature Creation + eval_dataset = eval_examples.map( + prepare_validation_features, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + if data_args.max_eval_samples is not None: + # During Feature creation dataset samples might increase, we will select required samples again + eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) + processed_raw_datasets["validation"] = eval_dataset + + if training_args.do_predict: + if "test" not in raw_datasets: + raise ValueError("--do_predict requires a test dataset") + predict_examples = raw_datasets["test"] + if data_args.max_predict_samples is not None: + # We will select sample from whole data + predict_examples = predict_examples.select(range(data_args.max_predict_samples)) + # Predict Feature Creation + predict_dataset = predict_examples.map( + prepare_validation_features, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + if data_args.max_predict_samples is not None: + # During Feature creation dataset samples might increase, we will select required samples again + predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) + processed_raw_datasets["test"] = predict_dataset + # endregion + + # region Metrics and Post-processing: + def post_processing_function(examples, features, predictions, stage="eval"): + # Post-processing: we match the start logits and end logits to answers in the original context. + predictions = postprocess_qa_predictions( + examples=examples, + features=features, + predictions=predictions, + version_2_with_negative=data_args.version_2_with_negative, + n_best_size=data_args.n_best_size, + max_answer_length=data_args.max_answer_length, + null_score_diff_threshold=data_args.null_score_diff_threshold, + output_dir=training_args.output_dir, + prefix=stage, + ) + # Format the result to the format the metric expects. + if data_args.version_2_with_negative: + formatted_predictions = [ + {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items() + ] + else: + formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()] + + references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples] + return EvalPrediction(predictions=formatted_predictions, label_ids=references) + + metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad") + + def compute_metrics(p: EvalPrediction): + return metric.compute(predictions=p.predictions, references=p.label_ids) + + # Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor + def create_and_fill_np_array(start_or_end_logits, dataset, max_len): + """ + Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor + + Args: + start_or_end_logits(:obj:`tensor`): + This is the output predictions of the model. We can only enter either start or end logits. + eval_dataset: Evaluation dataset + max_len(:obj:`int`): + The maximum length of the output tensor. ( See the model.eval() part for more details ) + """ + + step = 0 + # create a numpy array and fill it with -100. + logits_concat = np.full((len(dataset), max_len), -100, dtype=np.float64) + # Now since we have create an array now we will populate it with the outputs of the model. + for i, output_logit in enumerate(start_or_end_logits): # populate columns + # We have to fill it such that we have to take the whole tensor and replace it on the newly created array + # And after every iteration we have to change the step + + batch_size = output_logit.shape[0] + cols = output_logit.shape[1] + + if step + batch_size < len(dataset): + logits_concat[step : step + batch_size, :cols] = output_logit + else: + logits_concat[step:, :cols] = output_logit[: len(dataset) - step] + + step += batch_size + + return logits_concat + + # endregion + + # region Training steps and logging init + train_dataset = processed_raw_datasets["train"] + eval_dataset = processed_raw_datasets["validation"] + + # Log a few random samples from the training set: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # Define a summary writer + summary_writer = tensorboard.SummaryWriter(training_args.output_dir) + summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)}) + + def write_train_metric(summary_writer, train_metrics, train_time, step): + summary_writer.scalar("train_time", train_time, step) + + train_metrics = get_metrics(train_metrics) + for key, vals in train_metrics.items(): + tag = f"train_{key}" + for i, val in enumerate(vals): + summary_writer.scalar(tag, val, step - len(vals) + i + 1) + + def write_eval_metric(summary_writer, eval_metrics, step): + for metric_name, value in eval_metrics.items(): + summary_writer.scalar(f"eval_{metric_name}", value, step) + + num_epochs = int(training_args.num_train_epochs) + rng = jax.random.PRNGKey(training_args.seed) + dropout_rngs = jax.random.split(rng, jax.local_device_count()) + + train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count() + eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count() + # endregion + + # region Load model + model = FlaxAutoModelForQuestionAnswering.from_pretrained( + model_args.model_name_or_path, + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + seed=training_args.seed, + dtype=getattr(jnp, model_args.dtype), + ) + + learning_rate_fn = create_learning_rate_fn( + len(train_dataset), + train_batch_size, + training_args.num_train_epochs, + training_args.warmup_steps, + training_args.learning_rate, + ) + + state = create_train_state(model, learning_rate_fn, num_labels=max_seq_length, training_args=training_args) + # endregion + + # region Define train step functions + def train_step( + state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey + ) -> Tuple[train_state.TrainState, float]: + """Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`.""" + dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) + start_positions = batch.pop("start_positions") + end_positions = batch.pop("end_positions") + targets = (start_positions, end_positions) + + def loss_fn(params): + logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True) + loss = state.loss_fn(logits, targets) + return loss + + grad_fn = jax.value_and_grad(loss_fn) + loss, grad = grad_fn(state.params) + grad = jax.lax.pmean(grad, "batch") + new_state = state.apply_gradients(grads=grad) + metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch") + return new_state, metrics, new_dropout_rng + + p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,)) + # endregion + + # region Define eval step functions + def eval_step(state, batch): + logits = state.apply_fn(**batch, params=state.params, train=False) + return state.logits_fn(logits) + + p_eval_step = jax.pmap(eval_step, axis_name="batch") + # endregion + + # region Define train and eval loop + logger.info(f"===== Starting training ({num_epochs} epochs) =====") + train_time = 0 + + # make sure weights are replicated on each device + state = replicate(state) + + train_time = 0 + step_per_epoch = len(train_dataset) // train_batch_size + total_steps = step_per_epoch * num_epochs + epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) + for epoch in epochs: + + train_start = time.time() + train_metrics = [] + + # Create sampling rng + rng, input_rng = jax.random.split(rng) + + # train + for step, batch in enumerate( + tqdm( + train_data_collator(input_rng, train_dataset, train_batch_size), + total=step_per_epoch, + desc="Training...", + position=1, + ), + 1, + ): + state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs) + train_metrics.append(train_metric) + + cur_step = epoch * step_per_epoch + step + + if cur_step % training_args.logging_steps == 0 and cur_step > 0: + # Save metrics + train_metric = unreplicate(train_metric) + train_time += time.time() - train_start + if jax.process_index() == 0: + write_train_metric(summary_writer, train_metrics, train_time, cur_step) + + epochs.write( + f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})" + ) + + train_metrics = [] + + if ( + training_args.do_eval + and (cur_step % training_args.eval_steps == 0 or cur_step % step_per_epoch == 0) + and cur_step > 0 + ): + + eval_metrics = {} + all_start_logits = [] + all_end_logits = [] + # evaluate + for batch in tqdm( + eval_data_collator(eval_dataset, eval_batch_size), + total=len(eval_dataset) // eval_batch_size, + desc="Evaluating ...", + position=2, + ): + _ = batch.pop("example_id") + _ = batch.pop("offset_mapping") + predictions = p_eval_step(state, batch) + start_logits = np.array([pred for pred in chain(*predictions[0])]) + end_logits = np.array([pred for pred in chain(*predictions[1])]) + all_start_logits.append(start_logits) + all_end_logits.append(end_logits) + + # evaluate also on leftover examples (not divisible by batch_size) + num_leftover_samples = len(eval_dataset) % eval_batch_size + + # make sure leftover batch is evaluated on one device + if num_leftover_samples > 0 and jax.process_index() == 0: + # take leftover samples + batch = eval_dataset[-num_leftover_samples:] + batch = {k: np.array(v) for k, v in batch.items()} + _ = batch.pop("example_id") + _ = batch.pop("offset_mapping") + + predictions = eval_step(unreplicate(state), batch) + start_logits = np.array([pred for pred in predictions[0]]) + end_logits = np.array([pred for pred in predictions[1]]) + all_start_logits.append(start_logits) + all_end_logits.append(end_logits) + + max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor + + # concatenate the numpy array + start_logits_concat = create_and_fill_np_array(all_start_logits, eval_dataset, max_len) + end_logits_concat = create_and_fill_np_array(all_end_logits, eval_dataset, max_len) + + # delete the list of numpy arrays + del all_start_logits + del all_end_logits + outputs_numpy = (start_logits_concat, end_logits_concat) + prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy) + eval_metrics = compute_metrics(prediction) + + logger.info(f"Step... ({cur_step}/{total_steps} | Evaluation metrics: {eval_metrics})") + + if jax.process_index() == 0: + write_eval_metric(summary_writer, eval_metrics, cur_step) + + if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps): + # save checkpoint after each epoch and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(unreplicate(state.params)) + model.save_pretrained( + training_args.output_dir, + params=params, + push_to_hub=training_args.push_to_hub, + commit_message=f"Saving weights and logs of step {cur_step}", + ) + epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}" + # endregion + + +if __name__ == "__main__": + main() diff --git a/examples/flax/question-answering/utils_qa.py b/examples/flax/question-answering/utils_qa.py new file mode 100644 index 0000000000..1157849c99 --- /dev/null +++ b/examples/flax/question-answering/utils_qa.py @@ -0,0 +1,427 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Post-processing utilities for question answering. +""" +import collections +import json +import logging +import os +from typing import Optional, Tuple + +import numpy as np +from tqdm.auto import tqdm + + +logger = logging.getLogger(__name__) + + +def postprocess_qa_predictions( + examples, + features, + predictions: Tuple[np.ndarray, np.ndarray], + version_2_with_negative: bool = False, + n_best_size: int = 20, + max_answer_length: int = 30, + null_score_diff_threshold: float = 0.0, + output_dir: Optional[str] = None, + prefix: Optional[str] = None, + log_level: Optional[int] = logging.WARNING, +): + """ + Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the + original contexts. This is the base postprocessing functions for models that only return start and end logits. + + Args: + examples: The non-preprocessed dataset (see the main script for more information). + features: The processed dataset (see the main script for more information). + predictions (:obj:`Tuple[np.ndarray, np.ndarray]`): + The predictions of the model: two arrays containing the start logits and the end logits respectively. Its + first dimension must match the number of elements of :obj:`features`. + version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the underlying dataset contains examples with no answers. + n_best_size (:obj:`int`, `optional`, defaults to 20): + The total number of n-best predictions to generate when looking for an answer. + max_answer_length (:obj:`int`, `optional`, defaults to 30): + The maximum length of an answer that can be generated. This is needed because the start and end predictions + are not conditioned on one another. + null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0): + The threshold used to select the null answer: if the best answer has a score that is less than the score of + the null answer minus this threshold, the null answer is selected for this example (note that the score of + the null answer for an example giving several features is the minimum of the scores for the null answer on + each feature: all features must be aligned on the fact they `want` to predict a null answer). + + Only useful when :obj:`version_2_with_negative` is :obj:`True`. + output_dir (:obj:`str`, `optional`): + If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if + :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null + answers, are saved in `output_dir`. + prefix (:obj:`str`, `optional`): + If provided, the dictionaries mentioned above are saved with `prefix` added to their names. + log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): + ``logging`` log level (e.g., ``logging.WARNING``) + """ + assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)." + all_start_logits, all_end_logits = predictions + + assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features." + + # Build a map example to its corresponding features. + example_id_to_index = {k: i for i, k in enumerate(examples["id"])} + features_per_example = collections.defaultdict(list) + for i, feature in enumerate(features): + features_per_example[example_id_to_index[feature["example_id"]]].append(i) + + # The dictionaries we have to fill. + all_predictions = collections.OrderedDict() + all_nbest_json = collections.OrderedDict() + if version_2_with_negative: + scores_diff_json = collections.OrderedDict() + + # Logging. + logger.setLevel(log_level) + logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") + + # Let's loop over all the examples! + for example_index, example in enumerate(tqdm(examples)): + # Those are the indices of the features associated to the current example. + feature_indices = features_per_example[example_index] + + min_null_prediction = None + prelim_predictions = [] + + # Looping through all the features associated to the current example. + for feature_index in feature_indices: + # We grab the predictions of the model for this feature. + start_logits = all_start_logits[feature_index] + end_logits = all_end_logits[feature_index] + # This is what will allow us to map some the positions in our logits to span of texts in the original + # context. + offset_mapping = features[feature_index]["offset_mapping"] + # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context + # available in the current feature. + token_is_max_context = features[feature_index].get("token_is_max_context", None) + + # Update minimum null prediction. + feature_null_score = start_logits[0] + end_logits[0] + if min_null_prediction is None or min_null_prediction["score"] > feature_null_score: + min_null_prediction = { + "offsets": (0, 0), + "score": feature_null_score, + "start_logit": start_logits[0], + "end_logit": end_logits[0], + } + + # Go through all possibilities for the `n_best_size` greater start and end logits. + start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist() + end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist() + for start_index in start_indexes: + for end_index in end_indexes: + # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond + # to part of the input_ids that are not in the context. + if ( + start_index >= len(offset_mapping) + or end_index >= len(offset_mapping) + or offset_mapping[start_index] is None + or offset_mapping[end_index] is None + ): + continue + # Don't consider answers with a length that is either < 0 or > max_answer_length. + if end_index < start_index or end_index - start_index + 1 > max_answer_length: + continue + # Don't consider answer that don't have the maximum context available (if such information is + # provided). + if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): + continue + prelim_predictions.append( + { + "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), + "score": start_logits[start_index] + end_logits[end_index], + "start_logit": start_logits[start_index], + "end_logit": end_logits[end_index], + } + ) + if version_2_with_negative: + # Add the minimum null prediction + prelim_predictions.append(min_null_prediction) + null_score = min_null_prediction["score"] + + # Only keep the best `n_best_size` predictions. + predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size] + + # Add back the minimum null prediction if it was removed because of its low score. + if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions): + predictions.append(min_null_prediction) + + # Use the offsets to gather the answer text in the original context. + context = example["context"] + for pred in predictions: + offsets = pred.pop("offsets") + pred["text"] = context[offsets[0] : offsets[1]] + + # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid + # failure. + if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""): + predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}) + + # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using + # the LogSumExp trick). + scores = np.array([pred.pop("score") for pred in predictions]) + exp_scores = np.exp(scores - np.max(scores)) + probs = exp_scores / exp_scores.sum() + + # Include the probabilities in our predictions. + for prob, pred in zip(probs, predictions): + pred["probability"] = prob + + # Pick the best prediction. If the null answer is not possible, this is easy. + if not version_2_with_negative: + all_predictions[example["id"]] = predictions[0]["text"] + else: + # Otherwise we first need to find the best non-empty prediction. + i = 0 + while predictions[i]["text"] == "": + i += 1 + best_non_null_pred = predictions[i] + + # Then we compare to the null prediction using the threshold. + score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"] + scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable. + if score_diff > null_score_diff_threshold: + all_predictions[example["id"]] = "" + else: + all_predictions[example["id"]] = best_non_null_pred["text"] + + # Make `predictions` JSON-serializable by casting np.float back to float. + all_nbest_json[example["id"]] = [ + {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()} + for pred in predictions + ] + + # If we have an output_dir, let's save all those dicts. + if output_dir is not None: + assert os.path.isdir(output_dir), f"{output_dir} is not a directory." + + prediction_file = os.path.join( + output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" + ) + nbest_file = os.path.join( + output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json" + ) + if version_2_with_negative: + null_odds_file = os.path.join( + output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json" + ) + + logger.info(f"Saving predictions to {prediction_file}.") + with open(prediction_file, "w") as writer: + writer.write(json.dumps(all_predictions, indent=4) + "\n") + logger.info(f"Saving nbest_preds to {nbest_file}.") + with open(nbest_file, "w") as writer: + writer.write(json.dumps(all_nbest_json, indent=4) + "\n") + if version_2_with_negative: + logger.info(f"Saving null_odds to {null_odds_file}.") + with open(null_odds_file, "w") as writer: + writer.write(json.dumps(scores_diff_json, indent=4) + "\n") + + return all_predictions + + +def postprocess_qa_predictions_with_beam_search( + examples, + features, + predictions: Tuple[np.ndarray, np.ndarray], + version_2_with_negative: bool = False, + n_best_size: int = 20, + max_answer_length: int = 30, + start_n_top: int = 5, + end_n_top: int = 5, + output_dir: Optional[str] = None, + prefix: Optional[str] = None, + log_level: Optional[int] = logging.WARNING, +): + """ + Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the + original contexts. This is the postprocessing functions for models that return start and end logits, indices, as well as + cls token predictions. + + Args: + examples: The non-preprocessed dataset (see the main script for more information). + features: The processed dataset (see the main script for more information). + predictions (:obj:`Tuple[np.ndarray, np.ndarray]`): + The predictions of the model: two arrays containing the start logits and the end logits respectively. Its + first dimension must match the number of elements of :obj:`features`. + version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the underlying dataset contains examples with no answers. + n_best_size (:obj:`int`, `optional`, defaults to 20): + The total number of n-best predictions to generate when looking for an answer. + max_answer_length (:obj:`int`, `optional`, defaults to 30): + The maximum length of an answer that can be generated. This is needed because the start and end predictions + are not conditioned on one another. + start_n_top (:obj:`int`, `optional`, defaults to 5): + The number of top start logits too keep when searching for the :obj:`n_best_size` predictions. + end_n_top (:obj:`int`, `optional`, defaults to 5): + The number of top end logits too keep when searching for the :obj:`n_best_size` predictions. + output_dir (:obj:`str`, `optional`): + If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if + :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null + answers, are saved in `output_dir`. + prefix (:obj:`str`, `optional`): + If provided, the dictionaries mentioned above are saved with `prefix` added to their names. + log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): + ``logging`` log level (e.g., ``logging.WARNING``) + """ + assert len(predictions) == 5, "`predictions` should be a tuple with five elements." + start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions + + assert len(predictions[0]) == len( + features + ), f"Got {len(predictions[0])} predicitions and {len(features)} features." + + # Build a map example to its corresponding features. + example_id_to_index = {k: i for i, k in enumerate(examples["id"])} + features_per_example = collections.defaultdict(list) + for i, feature in enumerate(features): + features_per_example[example_id_to_index[feature["example_id"]]].append(i) + + # The dictionaries we have to fill. + all_predictions = collections.OrderedDict() + all_nbest_json = collections.OrderedDict() + scores_diff_json = collections.OrderedDict() if version_2_with_negative else None + + # Logging. + logger.setLevel(log_level) + logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") + + # Let's loop over all the examples! + for example_index, example in enumerate(tqdm(examples)): + # Those are the indices of the features associated to the current example. + feature_indices = features_per_example[example_index] + + min_null_score = None + prelim_predictions = [] + + # Looping through all the features associated to the current example. + for feature_index in feature_indices: + # We grab the predictions of the model for this feature. + start_log_prob = start_top_log_probs[feature_index] + start_indexes = start_top_index[feature_index] + end_log_prob = end_top_log_probs[feature_index] + end_indexes = end_top_index[feature_index] + feature_null_score = cls_logits[feature_index] + # This is what will allow us to map some the positions in our logits to span of texts in the original + # context. + offset_mapping = features[feature_index]["offset_mapping"] + # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context + # available in the current feature. + token_is_max_context = features[feature_index].get("token_is_max_context", None) + + # Update minimum null prediction + if min_null_score is None or feature_null_score < min_null_score: + min_null_score = feature_null_score + + # Go through all possibilities for the `n_start_top`/`n_end_top` greater start and end logits. + for i in range(start_n_top): + for j in range(end_n_top): + start_index = int(start_indexes[i]) + j_index = i * end_n_top + j + end_index = int(end_indexes[j_index]) + # Don't consider out-of-scope answers (last part of the test should be unnecessary because of the + # p_mask but let's not take any risk) + if ( + start_index >= len(offset_mapping) + or end_index >= len(offset_mapping) + or offset_mapping[start_index] is None + or offset_mapping[end_index] is None + ): + continue + # Don't consider answers with a length negative or > max_answer_length. + if end_index < start_index or end_index - start_index + 1 > max_answer_length: + continue + # Don't consider answer that don't have the maximum context available (if such information is + # provided). + if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): + continue + prelim_predictions.append( + { + "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), + "score": start_log_prob[i] + end_log_prob[j_index], + "start_log_prob": start_log_prob[i], + "end_log_prob": end_log_prob[j_index], + } + ) + + # Only keep the best `n_best_size` predictions. + predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size] + + # Use the offsets to gather the answer text in the original context. + context = example["context"] + for pred in predictions: + offsets = pred.pop("offsets") + pred["text"] = context[offsets[0] : offsets[1]] + + # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid + # failure. + if len(predictions) == 0: + predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6}) + + # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using + # the LogSumExp trick). + scores = np.array([pred.pop("score") for pred in predictions]) + exp_scores = np.exp(scores - np.max(scores)) + probs = exp_scores / exp_scores.sum() + + # Include the probabilities in our predictions. + for prob, pred in zip(probs, predictions): + pred["probability"] = prob + + # Pick the best prediction and set the probability for the null answer. + all_predictions[example["id"]] = predictions[0]["text"] + if version_2_with_negative: + scores_diff_json[example["id"]] = float(min_null_score) + + # Make `predictions` JSON-serializable by casting np.float back to float. + all_nbest_json[example["id"]] = [ + {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()} + for pred in predictions + ] + + # If we have an output_dir, let's save all those dicts. + if output_dir is not None: + assert os.path.isdir(output_dir), f"{output_dir} is not a directory." + + prediction_file = os.path.join( + output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" + ) + nbest_file = os.path.join( + output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json" + ) + if version_2_with_negative: + null_odds_file = os.path.join( + output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json" + ) + + logger.info(f"Saving predictions to {prediction_file}.") + with open(prediction_file, "w") as writer: + writer.write(json.dumps(all_predictions, indent=4) + "\n") + logger.info(f"Saving nbest_preds to {nbest_file}.") + with open(nbest_file, "w") as writer: + writer.write(json.dumps(all_nbest_json, indent=4) + "\n") + if version_2_with_negative: + logger.info(f"Saving null_odds to {null_odds_file}.") + with open(null_odds_file, "w") as writer: + writer.write(json.dumps(scores_diff_json, indent=4) + "\n") + + return all_predictions, scores_diff_json diff --git a/utils/check_copies.py b/utils/check_copies.py index 9d19fba518..2008b0996e 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -28,7 +28,10 @@ PATH_TO_DOCS = "docs/source" REPO_PATH = "." # Mapping for files that are full copies of others (keys are copies, values the file to keep them up to data with) -FULL_COPIES = {"examples/tensorflow/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py"} +FULL_COPIES = { + "examples/tensorflow/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py", + "examples/flax/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py", +} LOCALIZED_READMES = {