From a4562552eb5efa8a12c61a3a7ebfd687dc72ee19 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 8 Aug 2022 14:25:56 +0200 Subject: [PATCH] [DX fix] Fixing QA pipeline streaming a dataset. (#18516) * [DX fix] Fixing QA pipeline streaming a dataset. QuestionAnsweringArgumentHandler would iterate over the whole dataset effectively killing all properties of the pipeline. This restores nice properties when using `Dataset` or `Generator` since those are meant to be consumed lazily. * Handling TF better. --- .../pipelines/question_answering.py | 17 ++++++++++++++++- .../test_pipelines_question_answering.py | 12 ++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/question_answering.py b/src/transformers/pipelines/question_answering.py index 0f5fbf0370..d58762035e 100644 --- a/src/transformers/pipelines/question_answering.py +++ b/src/transformers/pipelines/question_answering.py @@ -1,3 +1,4 @@ +import types import warnings from collections.abc import Iterable from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union @@ -22,8 +23,11 @@ if is_tf_available(): from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING + Dataset = None + if is_torch_available(): import torch + from torch.utils.data import Dataset from ..models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING @@ -82,6 +86,11 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler): else: raise ValueError(f"Unknown arguments {kwargs}") + # When user is sending a generator we need to trust it's a valid example + generator_types = (types.GeneratorType, Dataset) if Dataset is not None else (types.GeneratorType,) + if isinstance(inputs, generator_types): + return inputs + # Normalize inputs if isinstance(inputs, dict): inputs = [inputs] @@ -245,12 +254,18 @@ class QuestionAnsweringPipeline(ChunkPipeline): """ # Convert inputs to features + examples = self._args_parser(*args, **kwargs) - if len(examples) == 1: + if isinstance(examples, (list, tuple)) and len(examples) == 1: return super().__call__(examples[0], **kwargs) return super().__call__(examples, **kwargs) def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_question_len=64, max_seq_len=None): + # XXX: This is specal, args_parser will not handle anything generator or dataset like + # For those we expect user to send a simple valid example either directly as a SquadExample or simple dict. + # So we still need a little sanitation here. + if isinstance(example, dict): + example = SquadExample(None, example["question"], example["context"], None, None, None) if max_seq_len is None: max_seq_len = min(self.tokenizer.model_max_length, 384) diff --git a/tests/pipelines/test_pipelines_question_answering.py b/tests/pipelines/test_pipelines_question_answering.py index f34237612c..c3a0da2f2b 100644 --- a/tests/pipelines/test_pipelines_question_answering.py +++ b/tests/pipelines/test_pipelines_question_answering.py @@ -125,6 +125,18 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta): self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"}) + @require_torch + def test_small_model_pt_iterator(self): + # https://github.com/huggingface/transformers/issues/18510 + pipe = pipeline(model="sshleifer/tiny-distilbert-base-cased-distilled-squad", batch_size=16, framework="pt") + + def data(): + for i in range(10): + yield {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."} + + for outputs in pipe(data()): + self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"}) + @require_torch def test_small_model_pt_softmax_trick(self): question_answerer = pipeline(