Fix distil whisper segment computation (#33920)

* Fix distil whisper segment computation

* [run-slow] whisper
This commit is contained in:
Yoach Lacombe
2024-10-04 11:18:01 +02:00
committed by GitHub
parent 2bd4d5897d
commit 124713c32b
2 changed files with 19 additions and 1 deletions

View File

@@ -994,7 +994,10 @@ class WhisperGenerationMixin(GenerationMixin):
for v in range(len(values)):
layer_past_key_values = []
for w in values[v]:
if len(w) != 0:
layer_past_key_values.append(w[batch_idx][None].cpu())
else:
layer_past_key_values.append(w)
all_past_key_values.append(tuple(layer_past_key_values))
return tuple(all_past_key_values)

View File

@@ -2100,6 +2100,21 @@ class WhisperModelIntegrationTests(unittest.TestCase):
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
@slow
def test_distil_token_timestamp_generation(self):
# we actually just want to check that returning segments with distil model works
processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v3")
model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v3")
model.to(torch_device)
input_speech = np.concatenate(self._load_datasamples(4))
input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
input_features = input_features.to(torch_device)
_ = model.generate(
input_features, max_length=448, return_timestamps=True, return_token_timestamps=True, return_segments=True
)
@slow
def test_tiny_longform_timestamps_generation(self):
set_seed(0)