Fix distil whisper segment computation (#33920)
* Fix distil whisper segment computation * [run-slow] whisper
This commit is contained in:
@@ -994,7 +994,10 @@ class WhisperGenerationMixin(GenerationMixin):
|
||||
for v in range(len(values)):
|
||||
layer_past_key_values = []
|
||||
for w in values[v]:
|
||||
if len(w) != 0:
|
||||
layer_past_key_values.append(w[batch_idx][None].cpu())
|
||||
else:
|
||||
layer_past_key_values.append(w)
|
||||
all_past_key_values.append(tuple(layer_past_key_values))
|
||||
return tuple(all_past_key_values)
|
||||
|
||||
|
||||
@@ -2100,6 +2100,21 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
|
||||
def test_distil_token_timestamp_generation(self):
|
||||
# we actually just want to check that returning segments with distil model works
|
||||
processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v3")
|
||||
model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v3")
|
||||
model.to(torch_device)
|
||||
|
||||
input_speech = np.concatenate(self._load_datasamples(4))
|
||||
input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
|
||||
input_features = input_features.to(torch_device)
|
||||
|
||||
_ = model.generate(
|
||||
input_features, max_length=448, return_timestamps=True, return_token_timestamps=True, return_segments=True
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_tiny_longform_timestamps_generation(self):
|
||||
set_seed(0)
|
||||
|
||||
Reference in New Issue
Block a user