From 896300177bf9f35feac4698370212a80a5ab6138 Mon Sep 17 00:00:00 2001 From: Funtowicz Morgan Date: Wed, 22 Jul 2020 16:11:57 +0200 Subject: [PATCH] Expose padding_strategy on squad processor to fix QA pipeline performance regression (#5932) * Attempt to fix the way squad_convert_examples_to_features pad the elements for the QA pipeline. Signed-off-by: Morgan Funtowicz * Quality Signed-off-by: Morgan Funtowicz * Make the code easier to read and avoid testing multiple test the same thing. Signed-off-by: Morgan Funtowicz * missing enum value on truncation_strategy. Signed-off-by: Morgan Funtowicz * Rethinking for the easiest fix: expose the padding strategy on squad_convert_examples_to_features. Signed-off-by: Morgan Funtowicz * Remove unused imports. Signed-off-by: Morgan Funtowicz --- src/transformers/data/processors/squad.py | 26 ++++++++++++++++++----- src/transformers/pipelines.py | 2 ++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/transformers/data/processors/squad.py b/src/transformers/data/processors/squad.py index b6f32bdc98..190e37fc8c 100644 --- a/src/transformers/data/processors/squad.py +++ b/src/transformers/data/processors/squad.py @@ -9,6 +9,7 @@ from tqdm import tqdm from ...file_utils import is_tf_available, is_torch_available from ...tokenization_bert import whitespace_tokenize +from ...tokenization_utils_base import TruncationStrategy from .utils import DataProcessor @@ -87,7 +88,9 @@ def _is_whitespace(c): return False -def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_query_length, is_training): +def squad_convert_example_to_features( + example, max_seq_length, doc_stride, max_query_length, padding_strategy, is_training +): features = [] if is_training and not example.is_impossible: # Get start and end position @@ -141,11 +144,21 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q span_doc_tokens = all_doc_tokens while len(spans) * doc_stride < len(all_doc_tokens): + # Define the side we want to truncate / pad and the text/pair sorting + if tokenizer.padding_side == "right": + texts = truncated_query + pairs = span_doc_tokens + truncation = TruncationStrategy.ONLY_SECOND.value + else: + texts = span_doc_tokens + pairs = truncated_query + truncation = TruncationStrategy.ONLY_FIRST.value + encoded_dict = tokenizer.encode_plus( # TODO(thom) update this logic - truncated_query if tokenizer.padding_side == "right" else span_doc_tokens, - span_doc_tokens if tokenizer.padding_side == "right" else truncated_query, - truncation="only_second" if tokenizer.padding_side == "right" else "only_first", - padding="max_length", + texts, + pairs, + truncation=truncation, + padding=padding_strategy, max_length=max_seq_length, return_overflowing_tokens=True, stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens, @@ -285,6 +298,7 @@ def squad_convert_examples_to_features( doc_stride, max_query_length, is_training, + padding_strategy="max_length", return_dataset=False, threads=1, tqdm_enabled=True, @@ -300,6 +314,7 @@ def squad_convert_examples_to_features( doc_stride: The stride used when the context is too large and is split across several features. max_query_length: The maximum length of the query. is_training: whether to create features for model evaluation or model training. + padding_strategy: Default to "max_length". Which padding strategy to use return_dataset: Default False. Either 'pt' or 'tf'. if 'pt': returns a torch.data.TensorDataset, if 'tf': returns a tf.data.Dataset @@ -333,6 +348,7 @@ def squad_convert_examples_to_features( max_seq_length=max_seq_length, doc_stride=doc_stride, max_query_length=max_query_length, + padding_strategy=padding_strategy, is_training=is_training, ) features = list( diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index eb0e60a3a4..164f94b33c 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -36,6 +36,7 @@ from .modelcard import ModelCard from .tokenization_auto import AutoTokenizer from .tokenization_bert import BasicTokenizer from .tokenization_utils import PreTrainedTokenizer +from .tokenization_utils_base import PaddingStrategy if is_tf_available(): @@ -1318,6 +1319,7 @@ class QuestionAnsweringPipeline(Pipeline): max_seq_length=kwargs["max_seq_len"], doc_stride=kwargs["doc_stride"], max_query_length=kwargs["max_question_len"], + padding_strategy=PaddingStrategy.DO_NOT_PAD.value, is_training=False, tqdm_enabled=False, )