[DX fix] Fixing QA pipeline streaming a dataset. (#18516)
* [DX fix] Fixing QA pipeline streaming a dataset. QuestionAnsweringArgumentHandler would iterate over the whole dataset effectively killing all properties of the pipeline. This restores nice properties when using `Dataset` or `Generator` since those are meant to be consumed lazily. * Handling TF better.
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import types
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
@@ -22,8 +23,11 @@ if is_tf_available():
|
||||
|
||||
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
||||
|
||||
Dataset = None
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
||||
|
||||
@@ -82,6 +86,11 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler):
|
||||
else:
|
||||
raise ValueError(f"Unknown arguments {kwargs}")
|
||||
|
||||
# When user is sending a generator we need to trust it's a valid example
|
||||
generator_types = (types.GeneratorType, Dataset) if Dataset is not None else (types.GeneratorType,)
|
||||
if isinstance(inputs, generator_types):
|
||||
return inputs
|
||||
|
||||
# Normalize inputs
|
||||
if isinstance(inputs, dict):
|
||||
inputs = [inputs]
|
||||
@@ -245,12 +254,18 @@ class QuestionAnsweringPipeline(ChunkPipeline):
|
||||
"""
|
||||
|
||||
# Convert inputs to features
|
||||
|
||||
examples = self._args_parser(*args, **kwargs)
|
||||
if len(examples) == 1:
|
||||
if isinstance(examples, (list, tuple)) and len(examples) == 1:
|
||||
return super().__call__(examples[0], **kwargs)
|
||||
return super().__call__(examples, **kwargs)
|
||||
|
||||
def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_question_len=64, max_seq_len=None):
|
||||
# XXX: This is specal, args_parser will not handle anything generator or dataset like
|
||||
# For those we expect user to send a simple valid example either directly as a SquadExample or simple dict.
|
||||
# So we still need a little sanitation here.
|
||||
if isinstance(example, dict):
|
||||
example = SquadExample(None, example["question"], example["context"], None, None, None)
|
||||
|
||||
if max_seq_len is None:
|
||||
max_seq_len = min(self.tokenizer.model_max_length, 384)
|
||||
|
||||
@@ -125,6 +125,18 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
||||
|
||||
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt_iterator(self):
|
||||
# https://github.com/huggingface/transformers/issues/18510
|
||||
pipe = pipeline(model="sshleifer/tiny-distilbert-base-cased-distilled-squad", batch_size=16, framework="pt")
|
||||
|
||||
def data():
|
||||
for i in range(10):
|
||||
yield {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."}
|
||||
|
||||
for outputs in pipe(data()):
|
||||
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt_softmax_trick(self):
|
||||
question_answerer = pipeline(
|
||||
|
||||
Reference in New Issue
Block a user