From 9c8979e35fc4d0f991214368b58054573b8747ce Mon Sep 17 00:00:00 2001 From: Kamil Akesbi <45195979+kamilakesbi@users.noreply.github.com> Date: Tue, 7 May 2024 11:17:27 +0200 Subject: [PATCH] Word-level timestamps broken for short-form audio (#30325) * force chunk_length_s in AutomaticSpeechRecognitionPipeline * compute num_frames even when stride is None * add slow tests * fix test * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/pipelines/test_pipelines_automatic_speech_recognition.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * add input validation * fixup * small fix --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../pipelines/automatic_speech_recognition.py | 12 ++ ..._pipelines_automatic_speech_recognition.py | 131 ++++++++++++++++++ 2 files changed, 143 insertions(+) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index f2d0f13679..de1a9b57ac 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -446,6 +446,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): processed = self.feature_extractor( inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt" ) + if stride is None: + extra["segment_size"] = len(inputs) if self.torch_dtype is not None: processed = processed.to(dtype=self.torch_dtype) @@ -459,8 +461,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs): attention_mask = model_inputs.pop("attention_mask", None) stride = model_inputs.pop("stride", None) + segment_size = model_inputs.pop("segment_size", None) is_last = model_inputs.pop("is_last") + if stride is not None and segment_size is not None: + raise ValueError("segment_size must be used only when stride is None") + if self.type in {"seq2seq", "seq2seq_whisper"}: encoder = self.model.get_encoder() # Consume values so we can let extra information flow freely through @@ -488,6 +494,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): else: generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride] + else: + if isinstance(segment_size, int): + generate_kwargs["num_frames"] = segment_size // self.feature_extractor.hop_length + else: + generate_kwargs["num_frames"] = segment_size[0] // self.feature_extractor.hop_length + if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames: generate_kwargs["input_features"] = inputs else: diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index ddf9011808..a1ab294783 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -755,6 +755,94 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): }, ) + @slow + @require_torch + def test_whisper_large_timestamp_prediction(self): + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") + array = np.concatenate( + [ds[40]["audio"]["array"], ds[41]["audio"]["array"], ds[42]["audio"]["array"], ds[43]["audio"]["array"]] + ) + pipe = pipeline(model="openai/whisper-large-v3", return_timestamps=True) + + output = pipe(ds[40]["audio"]) + self.assertDictEqual( + output, + { + "text": " A man said to the universe, Sir, I exist.", + "chunks": [{"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 4.08)}], + }, + ) + + output = pipe(array, chunk_length_s=10) + + self.assertDictEqual( + nested_simplify(output), + { + "chunks": [ + {"timestamp": (0.0, 2.0), "text": (" A man said to the universe,")}, + {"timestamp": (2.0, 4.1), "text": (" Sir, I exist.")}, + {"timestamp": (5.14, 5.96), "text": (" Sweat covered")}, + {"timestamp": (5.96, 8.02), "text": (" Breon's body, trickling into")}, + {"timestamp": (8.02, 10.67), "text": (" the tight loincloth that was the only garment he wore,")}, + {"timestamp": (10.67, 13.67), "text": (" the cut on his chest still dripping blood,")}, + {"timestamp": (13.67, 17.61), "text": (" the ache of his overstrained eyes.")}, + { + "timestamp": (17.61, 24.0), + "text": ( + " Even the soaring arena around him with thousands of spectators were trivialities not worth thinking about." + ), + }, + { + "timestamp": (24.0, 29.94), + "text": (" His instant of panic was followed by a small, sharp blow high on his chest."), + }, + ], + "text": ( + " A man said to the universe, Sir, I exist. Sweat covered Breon's" + " body, trickling into the tight loincloth that was the only garment" + " he wore, the cut on his chest still dripping blood, the ache of his" + " overstrained eyes. Even the soaring arena around him with thousands" + " of spectators were trivialities not worth thinking about. His " + "instant of panic was followed by a small, sharp blow high on his chest." + ), + }, + ) + + output = pipe(array) + self.assertDictEqual( + output, + { + "chunks": [ + {"timestamp": (0.0, 1.96), "text": " A man said to the universe,"}, + {"timestamp": (2.7, 4.1), "text": " Sir, I exist."}, + {"timestamp": (5.14, 6.84), "text": " Sweat covered Brion's body,"}, + { + "timestamp": (7.4, 10.68), + "text": " trickling into the tight loincloth that was the only garment he wore,", + }, + {"timestamp": (11.6, 13.94), "text": " the cut on his chest still dripping blood,"}, + {"timestamp": (14.78, 16.72), "text": " the ache of his overstrained eyes,"}, + { + "timestamp": (17.32, 21.16), + "text": " even the soaring arena around him with the thousands of spectators", + }, + {"timestamp": (21.16, 23.94), "text": " were trivialities not worth thinking about."}, + { + "timestamp": (24.42, 29.94), + "text": " His instant panic was followed by a small sharp blow high on his chest.", + }, + ], + "text": ( + " A man said to the universe, Sir, I exist. Sweat covered Brion's body," + " trickling into the tight loincloth that was the only garment he wore, " + "the cut on his chest still dripping blood, the ache of his overstrained " + "eyes, even the soaring arena around him with the thousands of spectators " + "were trivialities not worth thinking about. His instant panic was followed " + "by a small sharp blow high on his chest." + ), + }, + ) + @slow @require_torch def test_whisper_word_timestamps_batched(self): @@ -799,6 +887,49 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): output = pipe(sample, batch_size=2) self.assertDictEqual(output, EXPECTED_OUTPUT) + @slow + @require_torch + def test_whisper_large_word_timestamps_batched(self): + pipe = pipeline( + task="automatic-speech-recognition", + model="openai/whisper-large-v3", + return_timestamps="word", + ) + data = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + sample = data[0]["audio"] + + # not the same output as test_simple_whisper_asr because of chunking + EXPECTED_OUTPUT = { + "text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", + "chunks": [ + {"text": " Mr.", "timestamp": (0.0, 0.74)}, + {"text": " Quilter", "timestamp": (0.74, 1.04)}, + {"text": " is", "timestamp": (1.04, 1.3)}, + {"text": " the", "timestamp": (1.3, 1.44)}, + {"text": " apostle", "timestamp": (1.44, 1.74)}, + {"text": " of", "timestamp": (1.74, 2.18)}, + {"text": " the", "timestamp": (2.18, 2.28)}, + {"text": " middle", "timestamp": (2.28, 2.5)}, + {"text": " classes,", "timestamp": (2.5, 3.0)}, + {"text": " and", "timestamp": (3.0, 3.4)}, + {"text": " we", "timestamp": (3.4, 3.5)}, + {"text": " are", "timestamp": (3.5, 3.6)}, + {"text": " glad", "timestamp": (3.6, 3.84)}, + {"text": " to", "timestamp": (3.84, 4.1)}, + {"text": " welcome", "timestamp": (4.1, 4.4)}, + {"text": " his", "timestamp": (4.4, 4.7)}, + {"text": " gospel.", "timestamp": (4.7, 5.34)}, + ], + } + + # batch size 1: copy the audio sample since pipeline consumes it + output = pipe(sample.copy(), batch_size=1) + self.assertDictEqual(output, EXPECTED_OUTPUT) + + # batch size 2: input audio is chunked into smaller pieces so it's testing batching + output = pipe(sample, batch_size=2) + self.assertDictEqual(output, EXPECTED_OUTPUT) + @require_torch @slow def test_torch_speech_encoder_decoder(self):