From 7f5d644e69068825bb5b6e84cdc56b3d3a9bd04f Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Mon, 29 Jul 2024 21:24:42 +0800 Subject: [PATCH] [pipeline] fix padding for 1-d tensors (#31776) * [pipeline] fix padding for 1-d tensors * add test * make style * Update tests/pipelines/test_pipelines_automatic_speech_recognition.py Co-authored-by: Kamil Akesbi <45195979+kamilakesbi@users.noreply.github.com> * Update tests/pipelines/test_pipelines_automatic_speech_recognition.py --------- Co-authored-by: Kamil Akesbi <45195979+kamilakesbi@users.noreply.github.com> --- src/transformers/pipelines/base.py | 3 +++ ...st_pipelines_automatic_speech_recognition.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 09f77402a1..85beb33b6f 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -90,6 +90,9 @@ def _pad(items, key, padding_value, padding_side): # Others include `attention_mask` etc... shape = items[0][key].shape dim = len(shape) + if dim == 1: + # We have a list of 1-dim torch tensors, which can be stacked without padding + return torch.cat([item[key] for item in items], dim=0) if key in ["pixel_values", "image"]: # This is probable image so padding shouldn't be necessary # B, C, H, W diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index d8810f67ee..777319d346 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -549,6 +549,23 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): output = speech_recognizer([filename], chunk_length_s=5, batch_size=4) self.assertEqual(output, [{"text": " A man said to the universe, Sir, I exist."}]) + @require_torch + @slow + def test_torch_whisper_batched(self): + speech_recognizer = pipeline( + task="automatic-speech-recognition", + model="openai/whisper-tiny", + framework="pt", + ) + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:2]") + EXPECTED_OUTPUT = [ + {"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."}, + {"text": " Nor is Mr. Quilters' manner less interesting than his matter."}, + ] + + output = speech_recognizer(ds["audio"], batch_size=2) + self.assertEqual(output, EXPECTED_OUTPUT) + @slow def test_find_longest_common_subsequence(self): max_source_positions = 1500