From 488a179ce10ab1da4eae4b5945e141fc9e0e9283 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 16 Jan 2023 15:04:27 +0100
Subject: [PATCH] Fixing batching pipelines on single items for ChunkPipeline (#21132)

* Fixing #20783

* Update src/transformers/pipelines/base.py

* Fixing some tests.

* Fixup.

* Remove ffmpeg dep + a bit more relaxed for bigbird QA precision.

* Better dataset.

* Prevent failing on TF.

* Better condition.

We can't use `can_use_iterator` since we cannot use it directly.
---
 src/transformers/pipelines/base.py            |  8 +++++++
 tests/pipelines/test_pipelines_common.py      | 24 +++++++++++++++++++
 .../test_pipelines_question_answering.py      |  4 +++-
 3 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index 038da8865f..28d6ee1937 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -1072,6 +1072,14 @@ class Pipeline(_ScikitCompat):
             )
         elif is_iterable:
             return self.iterate(inputs, preprocess_params, forward_params, postprocess_params)
+        elif self.framework == "pt" and isinstance(self, ChunkPipeline):
+            return next(
+                iter(
+                    self.get_iterator(
+                        [inputs], num_workers, batch_size, preprocess_params, forward_params, postprocess_params
+                    )
+                )
+            )
         else:
             return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
 
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index c06bd644c6..f5e75381e3 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -26,6 +26,7 @@ from functools import lru_cache
 from pathlib import Path
 from unittest import skipIf
 
+import datasets
 import numpy as np
 
 from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
@@ -965,6 +966,29 @@ class CustomPipelineTest(unittest.TestCase):
         self.assertEqual(counter.head_request_count, 1)
         self.assertEqual(counter.other_request_count, 0)
 
+    @require_torch
+    def test_chunk_pipeline_batching_single_file(self):
+        # Make sure we have cached the pipeline.
+        pipe = pipeline(model="hf-internal-testing/tiny-random-Wav2Vec2ForCTC")
+        ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
+        audio = ds[40]["audio"]["array"]
+
+        pipe = pipeline(model="hf-internal-testing/tiny-random-Wav2Vec2ForCTC")
+        # For some reason scoping doesn't work if not using `self.`
+        self.COUNT = 0
+        forward = pipe.model.forward
+
+        def new_forward(*args, **kwargs):
+            self.COUNT += 1
+            return forward(*args, **kwargs)
+
+        pipe.model.forward = new_forward
+
+        for out in pipe(audio, return_timestamps="char", chunk_length_s=3, stride_length_s=[1, 1], batch_size=1024):
+            pass
+
+        self.assertEqual(self.COUNT, 1)
+
 
 @require_torch
 @is_staging_test
diff --git a/tests/pipelines/test_pipelines_question_answering.py b/tests/pipelines/test_pipelines_question_answering.py
index afb7b95731..496b1685d9 100644
--- a/tests/pipelines/test_pipelines_question_answering.py
+++ b/tests/pipelines/test_pipelines_question_answering.py
@@ -106,11 +106,13 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
 
         # Using batch is OK
+        if question_answerer.tokenizer.pad_token_id is None:
+            question_answerer.tokenizer.pad_token_id = question_answerer.model.config.eos_token_id
         new_outputs = question_answerer(
             question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." * 20, batch_size=2
         )
         self.assertEqual(new_outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
-        self.assertEqual(outputs, new_outputs)
+        self.assertEqual(nested_simplify(outputs), nested_simplify(new_outputs))
 
     @require_torch
     def test_small_model_pt(self):