From 5d3cb760a072be65ee48a1368b5ac9fb9b390acd Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Fri, 20 Jan 2023 10:31:40 +0100 Subject: [PATCH] [Whisper] Fix pipeline after timestamp merges (#21198) * pass return_timestamps to pre-process * add a test to test it * test does not need device 0 * remove failing bit * update test --- .../pipelines/automatic_speech_recognition.py | 6 ++++- ..._pipelines_automatic_speech_recognition.py | 23 +++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 371a59fd7d..f6c7b0167f 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -400,6 +400,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): " only 1 version" ) forward_params["generate_kwargs"].update(generate_kwargs) + if return_timestamps is not None: + forward_params["generate_kwargs"]["return_timestamps"] = return_timestamps postprocess_params = {} if decoder_kwargs is not None: @@ -523,6 +525,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): generate_kwargs = {} is_last = model_inputs.pop("is_last") + return_timestamps = generate_kwargs.pop("return_timestamps", False) + if self.type == "seq2seq": encoder = self.model.get_encoder() # Consume values so we can let extra information flow freely through @@ -552,7 +556,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): stride = model_inputs.pop("stride", None) tokens = self.model.generate( input_features=model_inputs.pop("input_features"), - logits_processor=[WhisperTimeStampLogitsProcessor()], + logits_processor=[WhisperTimeStampLogitsProcessor()] if return_timestamps else None, **generate_kwargs, ) out = {"tokens": tokens} diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py 
b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index dc304272fd..d21c00f8ac 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -291,6 +291,29 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel output = speech_recognizer(filename) self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"}) + @require_torch + def test_return_timestamps_in_preprocess(self): + pipe = pipeline( + task="automatic-speech-recognition", + model="openai/whisper-tiny", + chunk_length_s=8, + stride_length_s=1, + ) + data = load_dataset("librispeech_asr", "clean", split="test", streaming=True) + sample = next(iter(data)) + pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language="en", task="transcribe") + + res = pipe(sample["audio"]["array"]) + self.assertEqual(res, {"text": " Conquered returned to its place amidst the tents."}) + res = pipe(sample["audio"]["array"], return_timestamps=True) + self.assertEqual( + res, + { + "text": " Conquered returned to its place amidst the tents.", + "chunks": [{"text": " Conquered returned to its place amidst the tents.", "timestamp": (0.0, 3.36)}], + }, + ) + @require_torch @slow def test_torch_whisper(self):