Token level timestamps for long-form generation in Whisper (#29148)

This commit is contained in:
Raushan Turganbay
2024-02-27 23:15:26 +05:00
committed by GitHub
parent 8a1faf2803
commit ddf7ac4237
4 changed files with 141 additions and 3 deletions

View File

@@ -1969,6 +1969,56 @@ class WhisperModelIntegrationTests(unittest.TestCase):
self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples)
@slow
def test_tiny_token_timestamp_generation_longform(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)
model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
input_speech = self._load_datasamples(5)
long_input_speech = np.concatenate(input_speech, dtype=np.float32)
inputs = processor.feature_extractor(
raw_speech=long_input_speech,
return_tensors="pt",
truncation=False, # False so the audio isn't truncated and whole audio is sent to the model
return_attention_mask=True,
padding=True,
)
inputs = inputs.to(torch_device)
generate_outputs = model.generate(**inputs, return_segments=True, return_token_timestamps=True)
token_timestamps_shape = [
[segment["token_timestamps"].shape for segment in segment_list]
for segment_list in generate_outputs["segments"]
]
tokens_shape = [
[segment["tokens"].shape for segment in segment_list] for segment_list in generate_outputs["segments"]
]
self.assertListEqual(tokens_shape, token_timestamps_shape)
# fmt: off
EXPECTED_OUTPUT = [
torch.tensor([0.0000, 0.4200, 0.8200, 0.9400, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0400, 2.3400, 2.5200, 2.6600, 3.2000, 3.4400, 3.5600, 3.6800, 3.8200, 4.1000, 4.3000, 4.5800, 4.9400, 5.4000, 6.3600]),
torch.tensor([ 6.5400, 6.5400, 6.7400, 6.9600, 7.2600, 7.3400, 7.5800, 7.5800, 7.6400, 7.8400, 8.1000, 8.5000, 9.0000, 9.4800, 9.7200, 10.2600, 11.1000]),
torch.tensor([11.2200, 11.2200, 11.4200, 11.6600, 12.0800, 12.4400, 12.5800, 12.8400, 13.1800, 13.6800, 14.0000, 14.2200, 14.6200, 14.9800, 15.2200, 15.6000, 15.9400, 16.2000, 16.5600, 16.8400, 16.9800]),
torch.tensor([16.9800, 16.9800, 17.3200, 18.1600, 18.6400, 18.8600, 19.2800, 19.5600, 19.8800, 20.1800, 20.3800, 20.7200, 21.1600, 21.5400, 21.9000, 22.2000, 22.4200, 22.8600, 23.7000]),
torch.tensor([23.7000, 23.7000, 23.9400, 24.1800, 24.3800, 24.8400, 25.2800, 25.6600, 25.9200, 26.2600, 26.4000, 26.5800, 26.7600, 27.1400, 27.3800, 28.0400, 28.3800, 28.8200, 29.3400, 29.5200]),
torch.tensor([29.4400, 29.4400, 29.7000, 30.0800, 30.3800, 30.5400, 30.8200, 31.0600, 31.6600, 31.9200, 32.3000, 32.4800, 32.6200, 33.6800]),
torch.tensor([33.8000, 33.8000, 33.9800, 33.9800, 34.1800, 34.4400, 34.6200, 35.0000, 35.2200, 35.3200, 35.5600, 35.9200, 36.3800, 36.6200, 36.6600, 36.9600, 37.3400, 37.9800, 38.5800, 38.7200, 38.9800, 39.4400, 39.5800, 39.8000, 40.1200, 40.2600]),
torch.tensor([40.5200, 40.5200, 40.6200, 41.1000, 41.5400, 41.9200, 42.1000, 42.3200, 42.3200, 43.0600, 44.6000]),
torch.tensor([44.7000, 44.7000, 44.8600, 44.9400, 45.1400, 45.1400, 45.2800, 45.6200, 45.9000, 46.2600, 47.1600, 47.4800, 47.7400, 48.1000, 48.2800, 48.4000, 48.6200, 48.8400, 49.0400, 49.2800, 49.4800, 49.6600, 49.9400, 50.5400]),
torch.tensor([50.5400, 50.5400, 50.6600, 50.8800, 51.2400, 51.7200, 52.8400]),
torch.tensor([52.9600, 52.9600, 53.0400, 53.2600, 53.4200, 53.5800, 53.9200, 54.1200, 54.7200, 54.9400, 55.2600, 55.6200, 55.9800, 56.5600, 56.8000, 56.9200, 57.3600, 57.9200, 58.1800, 58.5000, 58.6400, 58.8200]),
torch.tensor([58.6800, 58.6800, 59.1400, 59.5400, 59.9200, 60.1600, 60.3800, 60.8200, 61.6200, 62.2600, 75.2000]),
]
# fmt: on
for segment, exp_segment in zip(generate_outputs["segments"][0], EXPECTED_OUTPUT):
self.assertTrue(torch.allclose(segment["token_timestamps"], exp_segment))
@slow
def test_tiny_specaugment_librispeech(self):
torch_device = "cpu"