Token level timestamps for long-form generation in Whisper (#29148)
This commit is contained in:
committed by
GitHub
parent
8a1faf2803
commit
ddf7ac4237
@@ -1969,6 +1969,56 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples)
|
||||
|
||||
@slow
def test_tiny_token_timestamp_generation_longform(self):
    """Check token-level timestamps produced by long-form generation with whisper-tiny.

    Five data samples are concatenated into a single waveform that is fed to the
    feature extractor without truncation, and `generate` is called with
    `return_segments=True` and `return_token_timestamps=True`. The test asserts that

    * every returned segment carries one timestamp per token (matching shapes), and
    * the timestamps of the first batch item match the recorded reference values.
    """
    set_seed(0)
    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    model.to(torch_device)
    # NOTE(review): alignment heads are set explicitly here; presumably the checkpoint's
    # generation config does not ship them and token timestamps need them - confirm.
    model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

    input_speech = self._load_datasamples(5)
    long_input_speech = np.concatenate(input_speech, dtype=np.float32)
    inputs = processor.feature_extractor(
        raw_speech=long_input_speech,
        return_tensors="pt",
        truncation=False,  # False so the audio isn't truncated and whole audio is sent to the model
        return_attention_mask=True,
        padding=True,
    )

    inputs = inputs.to(torch_device)
    generate_outputs = model.generate(**inputs, return_segments=True, return_token_timestamps=True)

    # Each segment must expose exactly one timestamp per generated token.
    token_timestamps_shape = [
        [segment["token_timestamps"].shape for segment in segment_list]
        for segment_list in generate_outputs["segments"]
    ]
    tokens_shape = [
        [segment["tokens"].shape for segment in segment_list] for segment_list in generate_outputs["segments"]
    ]
    self.assertListEqual(tokens_shape, token_timestamps_shape)

    # fmt: off
    EXPECTED_OUTPUT = [
        torch.tensor([0.0000, 0.4200, 0.8200, 0.9400, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0400, 2.3400, 2.5200, 2.6600, 3.2000, 3.4400, 3.5600, 3.6800, 3.8200, 4.1000, 4.3000, 4.5800, 4.9400, 5.4000, 6.3600]),
        torch.tensor([ 6.5400, 6.5400, 6.7400, 6.9600, 7.2600, 7.3400, 7.5800, 7.5800, 7.6400, 7.8400, 8.1000, 8.5000, 9.0000, 9.4800, 9.7200, 10.2600, 11.1000]),
        torch.tensor([11.2200, 11.2200, 11.4200, 11.6600, 12.0800, 12.4400, 12.5800, 12.8400, 13.1800, 13.6800, 14.0000, 14.2200, 14.6200, 14.9800, 15.2200, 15.6000, 15.9400, 16.2000, 16.5600, 16.8400, 16.9800]),
        torch.tensor([16.9800, 16.9800, 17.3200, 18.1600, 18.6400, 18.8600, 19.2800, 19.5600, 19.8800, 20.1800, 20.3800, 20.7200, 21.1600, 21.5400, 21.9000, 22.2000, 22.4200, 22.8600, 23.7000]),
        torch.tensor([23.7000, 23.7000, 23.9400, 24.1800, 24.3800, 24.8400, 25.2800, 25.6600, 25.9200, 26.2600, 26.4000, 26.5800, 26.7600, 27.1400, 27.3800, 28.0400, 28.3800, 28.8200, 29.3400, 29.5200]),
        torch.tensor([29.4400, 29.4400, 29.7000, 30.0800, 30.3800, 30.5400, 30.8200, 31.0600, 31.6600, 31.9200, 32.3000, 32.4800, 32.6200, 33.6800]),
        torch.tensor([33.8000, 33.8000, 33.9800, 33.9800, 34.1800, 34.4400, 34.6200, 35.0000, 35.2200, 35.3200, 35.5600, 35.9200, 36.3800, 36.6200, 36.6600, 36.9600, 37.3400, 37.9800, 38.5800, 38.7200, 38.9800, 39.4400, 39.5800, 39.8000, 40.1200, 40.2600]),
        torch.tensor([40.5200, 40.5200, 40.6200, 41.1000, 41.5400, 41.9200, 42.1000, 42.3200, 42.3200, 43.0600, 44.6000]),
        torch.tensor([44.7000, 44.7000, 44.8600, 44.9400, 45.1400, 45.1400, 45.2800, 45.6200, 45.9000, 46.2600, 47.1600, 47.4800, 47.7400, 48.1000, 48.2800, 48.4000, 48.6200, 48.8400, 49.0400, 49.2800, 49.4800, 49.6600, 49.9400, 50.5400]),
        torch.tensor([50.5400, 50.5400, 50.6600, 50.8800, 51.2400, 51.7200, 52.8400]),
        torch.tensor([52.9600, 52.9600, 53.0400, 53.2600, 53.4200, 53.5800, 53.9200, 54.1200, 54.7200, 54.9400, 55.2600, 55.6200, 55.9800, 56.5600, 56.8000, 56.9200, 57.3600, 57.9200, 58.1800, 58.5000, 58.6400, 58.8200]),
        torch.tensor([58.6800, 58.6800, 59.1400, 59.5400, 59.9200, 60.1600, 60.3800, 60.8200, 61.6200, 62.2600, 75.2000]),
    ]
    # fmt: on

    first_item_segments = generate_outputs["segments"][0]
    # `zip` stops at the shorter input, so without this guard the loop below would pass
    # vacuously if generation returned fewer segments than the reference (or none at all).
    self.assertGreaterEqual(len(first_item_segments), len(EXPECTED_OUTPUT))
    for segment, exp_segment in zip(first_item_segments, EXPECTED_OUTPUT):
        # The reference tensors are created on CPU; bring the produced timestamps to CPU
        # as well so the comparison does not depend on which device the test runs on.
        self.assertTrue(torch.allclose(segment["token_timestamps"].cpu(), exp_segment))
|
||||
|
||||
@slow
|
||||
def test_tiny_specaugment_librispeech(self):
|
||||
torch_device = "cpu"
|
||||
|
||||
@@ -361,6 +361,70 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
@slow
@require_torch
def test_return_timestamps_in_preprocess_longform(self):
    """Check the ASR pipeline on long-form audio with each `return_timestamps` mode.

    The input is the first sample of the streamed LibriSpeech "clean" test split
    repeated eight times back to back. The pipeline is run three times:

    * without timestamps (only `"text"` is compared),
    * with `return_timestamps=True` (text plus chunk timestamps are compared),
    * with `return_timestamps="word"` (the first 15 word chunks are compared).
    """
    pipe = pipeline(
        task="automatic-speech-recognition",
        model="openai/whisper-tiny.en",
    )
    data = load_dataset("librispeech_asr", "clean", split="test", streaming=True)
    # The expected outputs below are the same sentence over and over: the audio is the
    # *first* sample repeated, not eight distinct samples. Fetch that sample once and
    # repeat it explicitly instead of re-creating the stream iterator for every copy.
    num_repeats = 8
    first_sample = next(iter(data))
    audio = np.concatenate([first_sample["audio"]["array"]] * num_repeats)

    res = pipe(audio)
    expected_output = {
        "text": " Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
        "the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
        "the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
        "the tents. Concord returned to its place amidst the tents."
    }
    self.assertEqual(res, expected_output)

    res = pipe(audio, return_timestamps=True)
    self.assertEqual(
        res,
        {
            "text": " Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents.",
            "chunks": [
                {"timestamp": (0.0, 3.22), "text": " Concord returned to its place amidst the tents."},
                {"timestamp": (3.22, 6.74), "text": " Concord returned to its place amidst the tents."},
                {"timestamp": (6.74, 10.26), "text": " Concord returned to its place amidst the tents."},
                {"timestamp": (10.26, 13.78), "text": " Concord returned to its place amidst the tents."},
                {"timestamp": (13.78, 17.3), "text": " Concord returned to its place amidst the tents."},
                {"timestamp": (17.3, 20.82), "text": " Concord returned to its place amidst the tents."},
                {"timestamp": (20.82, 24.34), "text": " Concord returned to its place amidst the tents."},
                {"timestamp": (24.34, 27.86), "text": " Concord returned to its place amidst the tents."},
            ],
        },
    )

    # NOTE(review): alignment heads are set by hand before requesting word-level
    # timestamps; presumably the checkpoint's generation config lacks them - confirm.
    pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
    res = pipe(audio, return_timestamps="word")

    # Only the leading word chunks are pinned; the rest of the output is not compared.
    # fmt: off
    self.assertEqual(
        res["chunks"][:15],
        [
            {"text": " Concord", "timestamp": (0.5, 0.94)},
            {"text": " returned", "timestamp": (0.94, 1.52)},
            {"text": " to", "timestamp": (1.52, 1.78)},
            {"text": " its", "timestamp": (1.78, 1.98)},
            {"text": " place", "timestamp": (1.98, 2.16)},
            {"text": " amidst", "timestamp": (2.16, 2.5)},
            {"text": " the", "timestamp": (2.5, 2.9)},
            {"text": " tents.", "timestamp": (2.9, 4.2)},
            {"text": " Concord", "timestamp": (4.2, 4.5)},
            {"text": " returned", "timestamp": (4.5, 5.0)},
            {"text": " to", "timestamp": (5.0, 5.28)},
            {"text": " its", "timestamp": (5.28, 5.48)},
            {"text": " place", "timestamp": (5.48, 5.7)},
            {"text": " amidst", "timestamp": (5.7, 6.02)},
            {"text": " the", "timestamp": (6.02, 6.4)},
        ],
    )
    # fmt: on
|
||||
|
||||
@require_torch
|
||||
def test_return_timestamps_in_init(self):
|
||||
# segment-level timestamps are accepted
|
||||
|
||||
Reference in New Issue
Block a user