Token level timestamps for long-form generation in Whisper (#29148)

This commit is contained in:
Raushan Turganbay
2024-02-27 23:15:26 +05:00
committed by GitHub
parent 8a1faf2803
commit ddf7ac4237
4 changed files with 141 additions and 3 deletions

View File

@@ -720,6 +720,7 @@ class WhisperGenerationMixin:
input_stride=input_stride, input_stride=input_stride,
prev_idx=prev_i, prev_idx=prev_i,
idx=i, idx=i,
return_token_timestamps=return_token_timestamps,
) )
current_segments[prev_i] += segments current_segments[prev_i] += segments
@@ -809,11 +810,15 @@ class WhisperGenerationMixin:
# remove eos token id # remove eos token id
if is_not_final and seek_sequence[-1] == generation_config.eos_token_id: if is_not_final and seek_sequence[-1] == generation_config.eos_token_id:
seek_sequence = seek_sequence[:-1] seek_sequence = seek_sequence[:-1]
if return_token_timestamps:
seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-1]
# remove all padding tokens # remove all padding tokens
if seek_sequence[-1] == generation_config.pad_token_id: if seek_sequence[-1] == generation_config.pad_token_id:
num_paddings = (seek_sequence == generation_config.pad_token_id).sum() num_paddings = (seek_sequence == generation_config.pad_token_id).sum()
seek_sequence = seek_sequence[:-num_paddings] seek_sequence = seek_sequence[:-num_paddings]
if return_token_timestamps:
seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings]
# check which sequences in batch need fallback & which should be skipped # check which sequences in batch need fallback & which should be skipped
needs_fallback[i], should_skip[i] = self._need_fallback( needs_fallback[i], should_skip[i] = self._need_fallback(
@@ -878,15 +883,18 @@ class WhisperGenerationMixin:
seek_outputs["token_timestamps"] = self._extract_token_timestamps( seek_outputs["token_timestamps"] = self._extract_token_timestamps(
seek_outputs, generation_config.alignment_heads, num_frames=num_frames seek_outputs, generation_config.alignment_heads, num_frames=num_frames
) )
seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, decoder_input_ids.shape[-1] :]
seek_outputs["sequences"] = seek_outputs["sequences"][:, decoder_input_ids.shape[-1] :] seek_outputs["sequences"] = seek_outputs["sequences"][:, decoder_input_ids.shape[-1] :]
def split_by_batch_index(values, key, batch_idx): def split_by_batch_index(values, key, batch_idx):
if key == "scores": if key == "scores":
return [v[batch_idx].cpu() for v in values] return [v[batch_idx].cpu() for v in values]
if key == "past_key_values": elif key == "past_key_values":
# we don't save `past_key_values` as this is too costly # we don't save `past_key_values` as this is too costly
return None return None
elif isinstance(values[batch_idx], tuple) and torch.is_tensor(values[batch_idx][0]):
return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
return values[batch_idx].cpu() return values[batch_idx].cpu()
sequence_tokens = seek_outputs["sequences"] sequence_tokens = seek_outputs["sequences"]
@@ -1611,6 +1619,7 @@ class WhisperGenerationMixin:
input_stride, input_stride,
prev_idx, prev_idx,
idx, idx,
return_token_timestamps,
): ):
# find the predicted "end of segment" predictions of Whisper # find the predicted "end of segment" predictions of Whisper
# "end of segment" predictions occur whenever Whisper predicts a timestamp token # "end of segment" predictions occur whenever Whisper predicts a timestamp token
@@ -1618,6 +1627,7 @@ class WhisperGenerationMixin:
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
timestamp_segment_indices.add_(1) timestamp_segment_indices.add_(1)
token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
# If whisper predicted a "end of segment" via a timestep token, let's go ever each # If whisper predicted a "end of segment" via a timestep token, let's go ever each
# "end of segment" prediction and slice the decoding into segments accordingly # "end of segment" prediction and slice the decoding into segments accordingly
@@ -1642,6 +1652,10 @@ class WhisperGenerationMixin:
"result": seek_outputs[idx], "result": seek_outputs[idx],
} }
) )
if return_token_timestamps:
segments[-1]["token_timestamps"] = (
token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
)
last_slice = current_slice last_slice = current_slice
if single_timestamp_ending: if single_timestamp_ending:
@@ -1661,7 +1675,6 @@ class WhisperGenerationMixin:
if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin: if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin:
# no consecutive timestamps but it has a timestamp; use the last one. # no consecutive timestamps but it has a timestamp; use the last one.
last_timestamp_pos = timestamps[-1].item() - timestamp_begin last_timestamp_pos = timestamps[-1].item() - timestamp_begin
segments = [ segments = [
{ {
"start": time_offset[prev_idx], "start": time_offset[prev_idx],
@@ -1670,6 +1683,8 @@ class WhisperGenerationMixin:
"result": seek_outputs[idx], "result": seek_outputs[idx],
} }
] ]
if return_token_timestamps:
segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx]
segment_offset = seek_num_frames[prev_idx] segment_offset = seek_num_frames[prev_idx]
return segments, segment_offset return segments, segment_offset

View File

@@ -483,6 +483,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
generate_kwargs["return_timestamps"] = return_timestamps generate_kwargs["return_timestamps"] = return_timestamps
if return_timestamps == "word": if return_timestamps == "word":
generate_kwargs["return_token_timestamps"] = True generate_kwargs["return_token_timestamps"] = True
generate_kwargs["return_segments"] = True
if stride is not None: if stride is not None:
if isinstance(stride, tuple): if isinstance(stride, tuple):
@@ -499,8 +500,16 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
attention_mask=attention_mask, attention_mask=attention_mask,
**generate_kwargs, **generate_kwargs,
) )
# whisper longform generation stores timestamps in "segments"
if return_timestamps == "word" and self.type == "seq2seq_whisper": if return_timestamps == "word" and self.type == "seq2seq_whisper":
if "segments" not in tokens:
out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]} out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
else:
token_timestamps = [
torch.cat([segment["token_timestamps"] for segment in segment_list])
for segment_list in tokens["segments"]
]
out = {"tokens": tokens["sequences"], "token_timestamps": token_timestamps}
else: else:
out = {"tokens": tokens} out = {"tokens": tokens}
if self.type == "seq2seq_whisper": if self.type == "seq2seq_whisper":

View File

@@ -1969,6 +1969,56 @@ class WhisperModelIntegrationTests(unittest.TestCase):
self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples) self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples)
@slow
def test_tiny_token_timestamp_generation_longform(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)
model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
input_speech = self._load_datasamples(5)
long_input_speech = np.concatenate(input_speech, dtype=np.float32)
inputs = processor.feature_extractor(
raw_speech=long_input_speech,
return_tensors="pt",
truncation=False, # False so the audio isn't truncated and whole audio is sent to the model
return_attention_mask=True,
padding=True,
)
inputs = inputs.to(torch_device)
generate_outputs = model.generate(**inputs, return_segments=True, return_token_timestamps=True)
token_timestamps_shape = [
[segment["token_timestamps"].shape for segment in segment_list]
for segment_list in generate_outputs["segments"]
]
tokens_shape = [
[segment["tokens"].shape for segment in segment_list] for segment_list in generate_outputs["segments"]
]
self.assertListEqual(tokens_shape, token_timestamps_shape)
# fmt: off
EXPECTED_OUTPUT = [
torch.tensor([0.0000, 0.4200, 0.8200, 0.9400, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0400, 2.3400, 2.5200, 2.6600, 3.2000, 3.4400, 3.5600, 3.6800, 3.8200, 4.1000, 4.3000, 4.5800, 4.9400, 5.4000, 6.3600]),
torch.tensor([ 6.5400, 6.5400, 6.7400, 6.9600, 7.2600, 7.3400, 7.5800, 7.5800, 7.6400, 7.8400, 8.1000, 8.5000, 9.0000, 9.4800, 9.7200, 10.2600, 11.1000]),
torch.tensor([11.2200, 11.2200, 11.4200, 11.6600, 12.0800, 12.4400, 12.5800, 12.8400, 13.1800, 13.6800, 14.0000, 14.2200, 14.6200, 14.9800, 15.2200, 15.6000, 15.9400, 16.2000, 16.5600, 16.8400, 16.9800]),
torch.tensor([16.9800, 16.9800, 17.3200, 18.1600, 18.6400, 18.8600, 19.2800, 19.5600, 19.8800, 20.1800, 20.3800, 20.7200, 21.1600, 21.5400, 21.9000, 22.2000, 22.4200, 22.8600, 23.7000]),
torch.tensor([23.7000, 23.7000, 23.9400, 24.1800, 24.3800, 24.8400, 25.2800, 25.6600, 25.9200, 26.2600, 26.4000, 26.5800, 26.7600, 27.1400, 27.3800, 28.0400, 28.3800, 28.8200, 29.3400, 29.5200]),
torch.tensor([29.4400, 29.4400, 29.7000, 30.0800, 30.3800, 30.5400, 30.8200, 31.0600, 31.6600, 31.9200, 32.3000, 32.4800, 32.6200, 33.6800]),
torch.tensor([33.8000, 33.8000, 33.9800, 33.9800, 34.1800, 34.4400, 34.6200, 35.0000, 35.2200, 35.3200, 35.5600, 35.9200, 36.3800, 36.6200, 36.6600, 36.9600, 37.3400, 37.9800, 38.5800, 38.7200, 38.9800, 39.4400, 39.5800, 39.8000, 40.1200, 40.2600]),
torch.tensor([40.5200, 40.5200, 40.6200, 41.1000, 41.5400, 41.9200, 42.1000, 42.3200, 42.3200, 43.0600, 44.6000]),
torch.tensor([44.7000, 44.7000, 44.8600, 44.9400, 45.1400, 45.1400, 45.2800, 45.6200, 45.9000, 46.2600, 47.1600, 47.4800, 47.7400, 48.1000, 48.2800, 48.4000, 48.6200, 48.8400, 49.0400, 49.2800, 49.4800, 49.6600, 49.9400, 50.5400]),
torch.tensor([50.5400, 50.5400, 50.6600, 50.8800, 51.2400, 51.7200, 52.8400]),
torch.tensor([52.9600, 52.9600, 53.0400, 53.2600, 53.4200, 53.5800, 53.9200, 54.1200, 54.7200, 54.9400, 55.2600, 55.6200, 55.9800, 56.5600, 56.8000, 56.9200, 57.3600, 57.9200, 58.1800, 58.5000, 58.6400, 58.8200]),
torch.tensor([58.6800, 58.6800, 59.1400, 59.5400, 59.9200, 60.1600, 60.3800, 60.8200, 61.6200, 62.2600, 75.2000]),
]
# fmt: on
for segment, exp_segment in zip(generate_outputs["segments"][0], EXPECTED_OUTPUT):
self.assertTrue(torch.allclose(segment["token_timestamps"], exp_segment))
@slow @slow
def test_tiny_specaugment_librispeech(self): def test_tiny_specaugment_librispeech(self):
torch_device = "cpu" torch_device = "cpu"

View File

@@ -361,6 +361,70 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
) )
# fmt: on # fmt: on
@slow
@require_torch
def test_return_timestamps_in_preprocess_longform(self):
pipe = pipeline(
task="automatic-speech-recognition",
model="openai/whisper-tiny.en",
)
data = load_dataset("librispeech_asr", "clean", split="test", streaming=True)
samples = [next(iter(data)) for _ in range(8)]
audio = np.concatenate([sample["audio"]["array"] for sample in samples])
res = pipe(audio)
expected_output = {
"text": " Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
"the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
"the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
"the tents. Concord returned to its place amidst the tents."
}
self.assertEqual(res, expected_output)
res = pipe(audio, return_timestamps=True)
self.assertEqual(
res,
{
"text": " Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents.",
"chunks": [
{"timestamp": (0.0, 3.22), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (3.22, 6.74), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (6.74, 10.26), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (10.26, 13.78), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (13.78, 17.3), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (17.3, 20.82), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (20.82, 24.34), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (24.34, 27.86), "text": " Concord returned to its place amidst the tents."},
],
},
)
pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
res = pipe(audio, return_timestamps="word")
# fmt: off
self.assertEqual(
res["chunks"][:15],
[
{"text": " Concord", "timestamp": (0.5, 0.94)},
{"text": " returned", "timestamp": (0.94, 1.52)},
{"text": " to", "timestamp": (1.52, 1.78)},
{"text": " its", "timestamp": (1.78, 1.98)},
{"text": " place", "timestamp": (1.98, 2.16)},
{"text": " amidst", "timestamp": (2.16, 2.5)},
{"text": " the", "timestamp": (2.5, 2.9)},
{"text": " tents.", "timestamp": (2.9, 4.2)},
{"text": " Concord", "timestamp": (4.2, 4.5)},
{"text": " returned", "timestamp": (4.5, 5.0)},
{"text": " to", "timestamp": (5.0, 5.28)},
{"text": " its", "timestamp": (5.28, 5.48)},
{"text": " place", "timestamp": (5.48, 5.7)},
{"text": " amidst", "timestamp": (5.7, 6.02)},
{"text": " the", "timestamp": (6.02, 6.4)}
],
)
# fmt: on
@require_torch @require_torch
def test_return_timestamps_in_init(self): def test_return_timestamps_in_init(self):
# segment-level timestamps are accepted # segment-level timestamps are accepted