Fix distil whisper segment computation (#33920)
* Fix distil whisper segment computation * [run-slow] whisper
This commit is contained in:
@@ -994,7 +994,10 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
for v in range(len(values)):
|
for v in range(len(values)):
|
||||||
layer_past_key_values = []
|
layer_past_key_values = []
|
||||||
for w in values[v]:
|
for w in values[v]:
|
||||||
|
if len(w) != 0:
|
||||||
layer_past_key_values.append(w[batch_idx][None].cpu())
|
layer_past_key_values.append(w[batch_idx][None].cpu())
|
||||||
|
else:
|
||||||
|
layer_past_key_values.append(w)
|
||||||
all_past_key_values.append(tuple(layer_past_key_values))
|
all_past_key_values.append(tuple(layer_past_key_values))
|
||||||
return tuple(all_past_key_values)
|
return tuple(all_past_key_values)
|
||||||
|
|
||||||
|
|||||||
@@ -2100,6 +2100,21 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
|
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
|
||||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_distil_token_timestamp_generation(self):
|
||||||
|
# we actually just want to check that returning segments with distil model works
|
||||||
|
processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v3")
|
||||||
|
model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v3")
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
input_speech = np.concatenate(self._load_datasamples(4))
|
||||||
|
input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
|
||||||
|
input_features = input_features.to(torch_device)
|
||||||
|
|
||||||
|
_ = model.generate(
|
||||||
|
input_features, max_length=448, return_timestamps=True, return_token_timestamps=True, return_segments=True
|
||||||
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_tiny_longform_timestamps_generation(self):
|
def test_tiny_longform_timestamps_generation(self):
|
||||||
set_seed(0)
|
set_seed(0)
|
||||||
|
|||||||
Reference in New Issue
Block a user