From 1c2bb3ac54d18a0bfc3e212b73f8d1c4aac3ea48 Mon Sep 17 00:00:00 2001 From: Kamil Akesbi <45195979+kamilakesbi@users.noreply.github.com> Date: Mon, 20 May 2024 10:53:58 +0200 Subject: [PATCH] add return_token_timestamps to WhisperProcessor (#30812) * compute num_frames in WhisperFeatureExtractor * add return_num_frames in WhisperFeatureProcessor + adapt pipeline * return_timestamps renaming + pipeline fix * fix * fix * fix * add tests * Update src/transformers/models/whisper/feature_extraction_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * apply review changes * fix * Update src/transformers/models/whisper/feature_extraction_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update tests/models/whisper/test_modeling_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * apply review * fix * review changes * Update src/transformers/models/whisper/feature_extraction_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * make style quality * EXPECTED_OUTPUT in single line * small numpy->torch fix * fix --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../whisper/feature_extraction_whisper.py | 8 ++ .../models/whisper/generation_whisper.py | 15 ++- .../pipelines/automatic_speech_recognition.py | 28 +++--- tests/models/whisper/test_modeling_whisper.py | 93 +++++++++++++++++++ 4 files changed, 127 insertions(+), 17 deletions(-) diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py index 508e85b91f..f2d6da5660 100644 --- a/src/transformers/models/whisper/feature_extraction_whisper.py +++ b/src/transformers/models/whisper/feature_extraction_whisper.py @@ -188,6 +188,7 @@ class 
WhisperFeatureExtractor(SequenceFeatureExtractor): sampling_rate: Optional[int] = None, do_normalize: Optional[bool] = None, device: Optional[str] = "cpu", + return_token_timestamps: Optional[bool] = None, **kwargs, ) -> BatchFeature: """ @@ -237,6 +238,9 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): device (`str`, *optional*, defaults to `'cpu'`): Specifies the device for computation of the log-mel spectrogram of audio signals in the `_torch_extract_fbank_features` method. (e.g., "cpu", "cuda") + return_token_timestamps (`bool`, *optional*, defaults to `None`): + Whether or not to return `num_frames`, the number of frames of each input in `raw_speech` (`len(raw_speech) // hop_length`). + These `num_frames` can be used by the model to compute word-level timestamps. """ if sampling_rate is not None: @@ -302,6 +306,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): if isinstance(input_features[0], List): padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features] + else: padded_inputs["input_features"] = input_features @@ -309,6 +314,9 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): # rescale from sample (48000) to feature (3000) padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length] + if return_token_timestamps: + padded_inputs["num_frames"] = [len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech] + if return_tensors is not None: padded_inputs = padded_inputs.convert_to_tensors(return_tensors) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index c58b0d35e5..2bdff6e534 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -209,11 +209,15 @@ class WhisperGenerationMixin: # 2. 
num_frames is different, compute the DTW matrix for each sample sequentially # we're using np.unique because num_frames can be int/list/tuple - if len(np.unique(num_frames)) == 1: - # if num_frames is the same, no need to recompute matrix, std and mean for each element of the batch - num_frames = num_frames if isinstance(num_frames, int) else num_frames[0] - + if isinstance(num_frames, int): weights = weights[..., : num_frames // 2] + + elif isinstance(num_frames, (list, tuple, np.ndarray)) and len(np.unique(num_frames)) == 1: + weights = weights[..., : num_frames[0] // 2] + + elif isinstance(num_frames, (torch.Tensor)) and len(torch.unique(num_frames)) == 1: + weights = weights[..., : num_frames[0] // 2] + else: # num_frames is of shape (batch_size,) whereas batch_size is truely batch_size*num_return_sequences repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames) @@ -231,7 +235,7 @@ class WhisperGenerationMixin: # Perform dynamic time warping on each element of the batch. for batch_idx in range(batch_size): - if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray)): + if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray, torch.Tensor)): matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2] # Normalize and smoothen the weights. @@ -475,6 +479,7 @@ class WhisperGenerationMixin: "The input name `inputs` is deprecated. Please make sure to use `input_features` instead.", FutureWarning, ) + # 1. 
prepare generation config generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index de1a9b57ac..123dbcdb67 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -443,11 +443,18 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): return_tensors="pt", ) else: - processed = self.feature_extractor( - inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt" - ) - if stride is None: - extra["segment_size"] = len(inputs) + if self.type == "seq2seq_whisper" and stride is None: + processed = self.feature_extractor( + inputs, + sampling_rate=self.feature_extractor.sampling_rate, + return_tensors="pt", + return_token_timestamps=True, + ) + extra["num_frames"] = processed.pop("num_frames") + else: + processed = self.feature_extractor( + inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt" + ) if self.torch_dtype is not None: processed = processed.to(dtype=self.torch_dtype) @@ -461,11 +468,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs): attention_mask = model_inputs.pop("attention_mask", None) stride = model_inputs.pop("stride", None) - segment_size = model_inputs.pop("segment_size", None) + num_frames = model_inputs.pop("num_frames", None) is_last = model_inputs.pop("is_last") - if stride is not None and segment_size is not None: - raise ValueError("segment_size must be used only when stride is None") + if stride is not None and num_frames is not None: + raise ValueError("num_frames must be used only when stride is None") if self.type in {"seq2seq", "seq2seq_whisper"}: encoder = self.model.get_encoder() @@ -495,10 +502,7 @@ class 
AutomaticSpeechRecognitionPipeline(ChunkPipeline): generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride] else: - if isinstance(segment_size, int): - generate_kwargs["num_frames"] = segment_size // self.feature_extractor.hop_length - else: - generate_kwargs["num_frames"] = segment_size[0] // self.feature_extractor.hop_length + generate_kwargs["num_frames"] = num_frames if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames: generate_kwargs["input_features"] = inputs diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index fed1b9c059..58acb5f2fd 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1950,6 +1950,69 @@ class WhisperModelIntegrationTests(unittest.TestCase): transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True) self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + @slow + def test_large_timestamp_generation(self): + set_seed(0) + processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") + model.to(torch_device) + + input_speech = np.concatenate(self._load_datasamples(4)) + input_features = processor( + input_speech, return_tensors="pt", sampling_rate=16_000, return_token_timestamps=True + ).input_features + input_features = input_features.to(torch_device) + + generated_ids = model.generate(input_features, max_length=448, return_timestamps=True).to("cpu") + + # fmt: off + EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50360, 50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 
293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50257]) + # fmt: on + self.assertTrue(torch.allclose(generated_ids, EXPECTED_OUTPUT)) + + EXPECTED_TRANSCRIPT = [ + { + "text": ( + " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." + " Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive" + " season of the year, with Christmas and roast beef looming before us, similes drawn from eating" + " and its results occur most readily to the mind. He has grave doubts whether Sir Frederick " + "Leighton's work is really Greek after all," + ), + "offsets": [ + { + "text": ( + " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." + ), + "timestamp": (0.0, 5.28), + }, + { + "text": " Nor is Mr. 
Quilter's manner less interesting than his matter.", + "timestamp": (6.34, 10.1), + }, + { + "text": ( + " He tells us that at this festive season of the year, with Christmas and roast beef looming before us," + ), + "timestamp": (10.92, 17.6), + }, + { + "text": (" similes drawn from eating and its results occur most readily to the mind."), + "timestamp": (18.44, 22.580000000000002), + }, + { + "text": ( + " He has grave doubts whether Sir Frederick Leighton's work is really Greek after all," + ), + "timestamp": (23.16, 28.68), + }, + ], + } + ] + + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True) + self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + @slow def test_tiny_token_timestamp_generation(self): set_seed(0) @@ -1979,6 +2042,36 @@ class WhisperModelIntegrationTests(unittest.TestCase): self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT)) + @slow + def test_large_token_timestamp_generation(self): + set_seed(0) + processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") + model.to(torch_device) + + input_speech = self._load_datasamples(4) + input_features = processor( + input_speech, return_tensors="pt", sampling_rate=16_000, return_token_timestamps=True + ) + input_features = input_features.to(torch_device) + + generate_outputs = model.generate( + **input_features, max_length=448, return_timestamps=True, return_token_timestamps=True + ) + + self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape) + + # fmt: off + EXPECTED_OUTPUT = torch.tensor([ + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6200, 0.7400, 0.8600, 1.0000, 1.0400, 1.3000, 1.4400, 1.7800, 2.1800, 2.2800, 2.5000, 2.9200, 3.0000, 3.3800, 3.5000, 3.6000, 3.8400, 4.1000, 4.4000, 4.6800, 5.1400, 5.3600, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 
5.8200, 5.8200, 5.8200, 5.8200], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6000, 0.9200, 1.2200, 1.3400, 1.4200, 1.5400, 1.5800, 1.7400, 2.0600, 2.3800, 3.0400, 3.3800, 3.6400, 4.1200, 4.3600, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5400, 0.8200, 1.1600, 1.4600, 1.7400, 1.8800, 2.3400, 2.7400, 3.1400, 3.2200, 3.5400, 4.2800, 4.5600, 4.8200, 5.0600, 5.3200, 5.6600, 5.9600, 6.1400, 6.4000, 6.8400, 7.8800, 8.0200, 8.3600, 8.7000, 9.0200, 9.3200, 9.5000, 9.8400, 10.3000, 10.6600, 11.0800, 11.3600, 11.4600, 11.8000, 12.4600], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5600, 0.7600, 1.0600, 1.4000, 1.8800, 2.2600, 2.6200, 2.8000, 2.9600, 3.0000, 3.2000, 3.4400, 3.6800, 4.0000, 4.6000, 5.0000, 5.3200, 5.4800, 6.0600, 6.0600, 6.1000, 6.3200, 6.7400, 7.0000, 7.2200, 7.4000, 7.7600, 8.0600, 8.5600, 8.8600, 8.9400, 9.1000, 9.3400, 9.8800, 9.8800, 9.8800] + ]) + # fmt: on + + self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT)) + @slow def test_tiny_token_timestamp_batch_generation(self): set_seed(0)