Fixing batching pipelines on single items for ChunkPipeline (#21132)
* Fixing #20783
* Update src/transformers/pipelines/base.py
* Fixing some tests.
* Fixup.
* Remove ffmpeg dep + a bit more relaxed for bigbird QA precision.
* Better dataset.
* Prevent failing on TF.
* Better condition. We can't use `can_use_iterator` since we cannot use it directly.
This commit is contained in:
@@ -1072,6 +1072,14 @@ class Pipeline(_ScikitCompat):
|
|||||||
)
|
)
|
||||||
elif is_iterable:
|
elif is_iterable:
|
||||||
return self.iterate(inputs, preprocess_params, forward_params, postprocess_params)
|
return self.iterate(inputs, preprocess_params, forward_params, postprocess_params)
|
||||||
|
elif self.framework == "pt" and isinstance(self, ChunkPipeline):
|
||||||
|
return next(
|
||||||
|
iter(
|
||||||
|
self.get_iterator(
|
||||||
|
[inputs], num_workers, batch_size, preprocess_params, forward_params, postprocess_params
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
|
return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from functools import lru_cache
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest import skipIf
|
from unittest import skipIf
|
||||||
|
|
||||||
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
|
from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
|
||||||
@@ -965,6 +966,29 @@ class CustomPipelineTest(unittest.TestCase):
|
|||||||
self.assertEqual(counter.head_request_count, 1)
|
self.assertEqual(counter.head_request_count, 1)
|
||||||
self.assertEqual(counter.other_request_count, 0)
|
self.assertEqual(counter.other_request_count, 0)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_chunk_pipeline_batching_single_file(self):
|
||||||
|
# Make sure we have cached the pipeline.
|
||||||
|
pipe = pipeline(model="hf-internal-testing/tiny-random-Wav2Vec2ForCTC")
|
||||||
|
ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||||
|
audio = ds[40]["audio"]["array"]
|
||||||
|
|
||||||
|
pipe = pipeline(model="hf-internal-testing/tiny-random-Wav2Vec2ForCTC")
|
||||||
|
# For some reason scoping doesn't work if not using `self.`
|
||||||
|
self.COUNT = 0
|
||||||
|
forward = pipe.model.forward
|
||||||
|
|
||||||
|
def new_forward(*args, **kwargs):
|
||||||
|
self.COUNT += 1
|
||||||
|
return forward(*args, **kwargs)
|
||||||
|
|
||||||
|
pipe.model.forward = new_forward
|
||||||
|
|
||||||
|
for out in pipe(audio, return_timestamps="char", chunk_length_s=3, stride_length_s=[1, 1], batch_size=1024):
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.assertEqual(self.COUNT, 1)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
@@ -106,11 +106,13 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
|||||||
self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
|
self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
|
||||||
|
|
||||||
# Using batch is OK
|
# Using batch is OK
|
||||||
|
if question_answerer.tokenizer.pad_token_id is None:
|
||||||
|
question_answerer.tokenizer.pad_token_id = question_answerer.model.config.eos_token_id
|
||||||
new_outputs = question_answerer(
|
new_outputs = question_answerer(
|
||||||
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." * 20, batch_size=2
|
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." * 20, batch_size=2
|
||||||
)
|
)
|
||||||
self.assertEqual(new_outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
|
self.assertEqual(new_outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
|
||||||
self.assertEqual(outputs, new_outputs)
|
self.assertEqual(nested_simplify(outputs), nested_simplify(new_outputs))
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_small_model_pt(self):
|
def test_small_model_pt(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user