[DX fix] Fixing QA pipeline streaming a dataset. (#18516)

* [DX fix] Fixing QA pipeline streaming a dataset.

QuestionAnsweringArgumentHandler would iterate over the whole dataset
effectively killing all properties of the pipeline.
This restores nice properties when using `Dataset` or `Generator` since
those are meant to be consumed lazily.

* Handling TF better.
This commit is contained in:
Nicolas Patry
2022-08-08 14:25:56 +02:00
committed by GitHub
parent 88a0ce57bb
commit a4562552eb
2 changed files with 28 additions and 1 deletions

View File

@@ -125,6 +125,18 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
@require_torch
def test_small_model_pt_iterator(self):
# https://github.com/huggingface/transformers/issues/18510
pipe = pipeline(model="sshleifer/tiny-distilbert-base-cased-distilled-squad", batch_size=16, framework="pt")
def data():
for i in range(10):
yield {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."}
for outputs in pipe(data()):
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
@require_torch
def test_small_model_pt_softmax_trick(self):
question_answerer = pipeline(