From 124713c32b62416bf7a773676866fd53924bc472 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Fri, 4 Oct 2024 11:18:01 +0200 Subject: [PATCH] Fix distil whisper segment computation (#33920) * Fix distil whisper segment computation * [run-slow] whisper --- .../models/whisper/generation_whisper.py | 5 ++++- tests/models/whisper/test_modeling_whisper.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 32e54e0a12..a3de765137 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -994,7 +994,10 @@ class WhisperGenerationMixin(GenerationMixin): for v in range(len(values)): layer_past_key_values = [] for w in values[v]: - layer_past_key_values.append(w[batch_idx][None].cpu()) + if len(w) != 0: + layer_past_key_values.append(w[batch_idx][None].cpu()) + else: + layer_past_key_values.append(w) all_past_key_values.append(tuple(layer_past_key_values)) return tuple(all_past_key_values) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 2925de5f22..e0eb27813e 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -2100,6 +2100,21 @@ class WhisperModelIntegrationTests(unittest.TestCase): transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True) self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + @slow + def test_distil_token_timestamp_generation(self): + # we actually just want to check that returning segments with distil model works + processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v3") + model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v3") + model.to(torch_device) + + input_speech = np.concatenate(self._load_datasamples(4)) + input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features + input_features = input_features.to(torch_device) + + _ = model.generate( + input_features, max_length=448, return_timestamps=True, return_token_timestamps=True, return_segments=True + ) + @slow def test_tiny_longform_timestamps_generation(self): set_seed(0)