Token level timestamps for long-form generation in Whisper (#29148)

This commit is contained in:
Raushan Turganbay
2024-02-27 23:15:26 +05:00
committed by GitHub
parent 8a1faf2803
commit ddf7ac4237
4 changed files with 141 additions and 3 deletions

View File

@@ -1969,6 +1969,56 @@ class WhisperModelIntegrationTests(unittest.TestCase):
self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples)
@slow
def test_tiny_token_timestamp_generation_longform(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)
model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
input_speech = self._load_datasamples(5)
long_input_speech = np.concatenate(input_speech, dtype=np.float32)
inputs = processor.feature_extractor(
raw_speech=long_input_speech,
return_tensors="pt",
truncation=False, # False so the audio isn't truncated and whole audio is sent to the model
return_attention_mask=True,
padding=True,
)
inputs = inputs.to(torch_device)
generate_outputs = model.generate(**inputs, return_segments=True, return_token_timestamps=True)
token_timestamps_shape = [
[segment["token_timestamps"].shape for segment in segment_list]
for segment_list in generate_outputs["segments"]
]
tokens_shape = [
[segment["tokens"].shape for segment in segment_list] for segment_list in generate_outputs["segments"]
]
self.assertListEqual(tokens_shape, token_timestamps_shape)
# fmt: off
EXPECTED_OUTPUT = [
torch.tensor([0.0000, 0.4200, 0.8200, 0.9400, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0400, 2.3400, 2.5200, 2.6600, 3.2000, 3.4400, 3.5600, 3.6800, 3.8200, 4.1000, 4.3000, 4.5800, 4.9400, 5.4000, 6.3600]),
torch.tensor([ 6.5400, 6.5400, 6.7400, 6.9600, 7.2600, 7.3400, 7.5800, 7.5800, 7.6400, 7.8400, 8.1000, 8.5000, 9.0000, 9.4800, 9.7200, 10.2600, 11.1000]),
torch.tensor([11.2200, 11.2200, 11.4200, 11.6600, 12.0800, 12.4400, 12.5800, 12.8400, 13.1800, 13.6800, 14.0000, 14.2200, 14.6200, 14.9800, 15.2200, 15.6000, 15.9400, 16.2000, 16.5600, 16.8400, 16.9800]),
torch.tensor([16.9800, 16.9800, 17.3200, 18.1600, 18.6400, 18.8600, 19.2800, 19.5600, 19.8800, 20.1800, 20.3800, 20.7200, 21.1600, 21.5400, 21.9000, 22.2000, 22.4200, 22.8600, 23.7000]),
torch.tensor([23.7000, 23.7000, 23.9400, 24.1800, 24.3800, 24.8400, 25.2800, 25.6600, 25.9200, 26.2600, 26.4000, 26.5800, 26.7600, 27.1400, 27.3800, 28.0400, 28.3800, 28.8200, 29.3400, 29.5200]),
torch.tensor([29.4400, 29.4400, 29.7000, 30.0800, 30.3800, 30.5400, 30.8200, 31.0600, 31.6600, 31.9200, 32.3000, 32.4800, 32.6200, 33.6800]),
torch.tensor([33.8000, 33.8000, 33.9800, 33.9800, 34.1800, 34.4400, 34.6200, 35.0000, 35.2200, 35.3200, 35.5600, 35.9200, 36.3800, 36.6200, 36.6600, 36.9600, 37.3400, 37.9800, 38.5800, 38.7200, 38.9800, 39.4400, 39.5800, 39.8000, 40.1200, 40.2600]),
torch.tensor([40.5200, 40.5200, 40.6200, 41.1000, 41.5400, 41.9200, 42.1000, 42.3200, 42.3200, 43.0600, 44.6000]),
torch.tensor([44.7000, 44.7000, 44.8600, 44.9400, 45.1400, 45.1400, 45.2800, 45.6200, 45.9000, 46.2600, 47.1600, 47.4800, 47.7400, 48.1000, 48.2800, 48.4000, 48.6200, 48.8400, 49.0400, 49.2800, 49.4800, 49.6600, 49.9400, 50.5400]),
torch.tensor([50.5400, 50.5400, 50.6600, 50.8800, 51.2400, 51.7200, 52.8400]),
torch.tensor([52.9600, 52.9600, 53.0400, 53.2600, 53.4200, 53.5800, 53.9200, 54.1200, 54.7200, 54.9400, 55.2600, 55.6200, 55.9800, 56.5600, 56.8000, 56.9200, 57.3600, 57.9200, 58.1800, 58.5000, 58.6400, 58.8200]),
torch.tensor([58.6800, 58.6800, 59.1400, 59.5400, 59.9200, 60.1600, 60.3800, 60.8200, 61.6200, 62.2600, 75.2000]),
]
# fmt: on
for segment, exp_segment in zip(generate_outputs["segments"][0], EXPECTED_OUTPUT):
self.assertTrue(torch.allclose(segment["token_timestamps"], exp_segment))
@slow
def test_tiny_specaugment_librispeech(self):
torch_device = "cpu"

View File

@@ -361,6 +361,70 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
)
# fmt: on
@slow
@require_torch
def test_return_timestamps_in_preprocess_longform(self):
pipe = pipeline(
task="automatic-speech-recognition",
model="openai/whisper-tiny.en",
)
data = load_dataset("librispeech_asr", "clean", split="test", streaming=True)
samples = [next(iter(data)) for _ in range(8)]
audio = np.concatenate([sample["audio"]["array"] for sample in samples])
res = pipe(audio)
expected_output = {
"text": " Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
"the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
"the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
"the tents. Concord returned to its place amidst the tents."
}
self.assertEqual(res, expected_output)
res = pipe(audio, return_timestamps=True)
self.assertEqual(
res,
{
"text": " Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents.",
"chunks": [
{"timestamp": (0.0, 3.22), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (3.22, 6.74), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (6.74, 10.26), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (10.26, 13.78), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (13.78, 17.3), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (17.3, 20.82), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (20.82, 24.34), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (24.34, 27.86), "text": " Concord returned to its place amidst the tents."},
],
},
)
pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
res = pipe(audio, return_timestamps="word")
# fmt: off
self.assertEqual(
res["chunks"][:15],
[
{"text": " Concord", "timestamp": (0.5, 0.94)},
{"text": " returned", "timestamp": (0.94, 1.52)},
{"text": " to", "timestamp": (1.52, 1.78)},
{"text": " its", "timestamp": (1.78, 1.98)},
{"text": " place", "timestamp": (1.98, 2.16)},
{"text": " amidst", "timestamp": (2.16, 2.5)},
{"text": " the", "timestamp": (2.5, 2.9)},
{"text": " tents.", "timestamp": (2.9, 4.2)},
{"text": " Concord", "timestamp": (4.2, 4.5)},
{"text": " returned", "timestamp": (4.5, 5.0)},
{"text": " to", "timestamp": (5.0, 5.28)},
{"text": " its", "timestamp": (5.28, 5.48)},
{"text": " place", "timestamp": (5.48, 5.7)},
{"text": " amidst", "timestamp": (5.7, 6.02)},
{"text": " the", "timestamp": (6.02, 6.4)}
],
)
# fmt: on
@require_torch
def test_return_timestamps_in_init(self):
# segment-level timestamps are accepted