[DX fix] Fixing QA pipeline streaming a dataset. (#18516)
* [DX fix] Fixing QA pipeline streaming a dataset. QuestionAnsweringArgumentHandler would iterate over the whole dataset effectively killing all properties of the pipeline. This restores nice properties when using `Dataset` or `Generator` since those are meant to be consumed lazily. * Handling TF better.
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import types
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||||
@@ -22,8 +23,11 @@ if is_tf_available():
|
|||||||
|
|
||||||
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
||||||
|
|
||||||
|
Dataset = None
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from ..models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
from ..models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
||||||
|
|
||||||
@@ -82,6 +86,11 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown arguments {kwargs}")
|
raise ValueError(f"Unknown arguments {kwargs}")
|
||||||
|
|
||||||
|
# When user is sending a generator we need to trust it's a valid example
|
||||||
|
generator_types = (types.GeneratorType, Dataset) if Dataset is not None else (types.GeneratorType,)
|
||||||
|
if isinstance(inputs, generator_types):
|
||||||
|
return inputs
|
||||||
|
|
||||||
# Normalize inputs
|
# Normalize inputs
|
||||||
if isinstance(inputs, dict):
|
if isinstance(inputs, dict):
|
||||||
inputs = [inputs]
|
inputs = [inputs]
|
||||||
@@ -245,12 +254,18 @@ class QuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Convert inputs to features
|
# Convert inputs to features
|
||||||
|
|
||||||
examples = self._args_parser(*args, **kwargs)
|
examples = self._args_parser(*args, **kwargs)
|
||||||
if len(examples) == 1:
|
if isinstance(examples, (list, tuple)) and len(examples) == 1:
|
||||||
return super().__call__(examples[0], **kwargs)
|
return super().__call__(examples[0], **kwargs)
|
||||||
return super().__call__(examples, **kwargs)
|
return super().__call__(examples, **kwargs)
|
||||||
|
|
||||||
def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_question_len=64, max_seq_len=None):
|
def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_question_len=64, max_seq_len=None):
|
||||||
|
# XXX: This is specal, args_parser will not handle anything generator or dataset like
|
||||||
|
# For those we expect user to send a simple valid example either directly as a SquadExample or simple dict.
|
||||||
|
# So we still need a little sanitation here.
|
||||||
|
if isinstance(example, dict):
|
||||||
|
example = SquadExample(None, example["question"], example["context"], None, None, None)
|
||||||
|
|
||||||
if max_seq_len is None:
|
if max_seq_len is None:
|
||||||
max_seq_len = min(self.tokenizer.model_max_length, 384)
|
max_seq_len = min(self.tokenizer.model_max_length, 384)
|
||||||
|
|||||||
@@ -125,6 +125,18 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
|||||||
|
|
||||||
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
|
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_model_pt_iterator(self):
|
||||||
|
# https://github.com/huggingface/transformers/issues/18510
|
||||||
|
pipe = pipeline(model="sshleifer/tiny-distilbert-base-cased-distilled-squad", batch_size=16, framework="pt")
|
||||||
|
|
||||||
|
def data():
|
||||||
|
for i in range(10):
|
||||||
|
yield {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."}
|
||||||
|
|
||||||
|
for outputs in pipe(data()):
|
||||||
|
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_small_model_pt_softmax_trick(self):
|
def test_small_model_pt_softmax_trick(self):
|
||||||
question_answerer = pipeline(
|
question_answerer = pipeline(
|
||||||
|
|||||||
Reference in New Issue
Block a user