From ddf7ac4237cfa08c50e65c297f7afa97a093fa91 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 27 Feb 2024 23:15:26 +0500 Subject: [PATCH] Token level timestamps for long-form generation in Whisper (#29148) --- .../models/whisper/generation_whisper.py | 19 +++++- .../pipelines/automatic_speech_recognition.py | 11 +++- tests/models/whisper/test_modeling_whisper.py | 50 +++++++++++++++ ..._pipelines_automatic_speech_recognition.py | 64 +++++++++++++++++++ 4 files changed, 141 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 0d6addb563..5b5957d534 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -720,6 +720,7 @@ class WhisperGenerationMixin: input_stride=input_stride, prev_idx=prev_i, idx=i, + return_token_timestamps=return_token_timestamps, ) current_segments[prev_i] += segments @@ -809,11 +810,15 @@ class WhisperGenerationMixin: # remove eos token id if is_not_final and seek_sequence[-1] == generation_config.eos_token_id: seek_sequence = seek_sequence[:-1] + if return_token_timestamps: + seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-1] # remove all padding tokens if seek_sequence[-1] == generation_config.pad_token_id: num_paddings = (seek_sequence == generation_config.pad_token_id).sum() seek_sequence = seek_sequence[:-num_paddings] + if return_token_timestamps: + seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings] # check which sequences in batch need fallback & which should be skipped needs_fallback[i], should_skip[i] = self._need_fallback( @@ -878,15 +883,18 @@ class WhisperGenerationMixin: seek_outputs["token_timestamps"] = self._extract_token_timestamps( seek_outputs, generation_config.alignment_heads, num_frames=num_frames ) + seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, decoder_input_ids.shape[-1] :] seek_outputs["sequences"] = seek_outputs["sequences"][:, decoder_input_ids.shape[-1] :] def split_by_batch_index(values, key, batch_idx): if key == "scores": return [v[batch_idx].cpu() for v in values] - if key == "past_key_values": + elif key == "past_key_values": # we don't save `past_key_values` as this is too costly return None + elif isinstance(values[batch_idx], tuple) and torch.is_tensor(values[batch_idx][0]): + return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values) return values[batch_idx].cpu() sequence_tokens = seek_outputs["sequences"] @@ -1611,6 +1619,7 @@ class WhisperGenerationMixin: input_stride, prev_idx, idx, + return_token_timestamps, ): # find the predicted "end of segment" predictions of Whisper # "end of segment" predictions occur whenever Whisper predicts a timestamp token @@ -1618,6 +1627,7 @@ class WhisperGenerationMixin: single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] timestamp_segment_indices.add_(1) + token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else [] # If whisper predicted a "end of segment" via a timestep token, let's go ever each # "end of segment" prediction and slice the decoding into segments accordingly @@ -1642,6 +1652,10 @@ class WhisperGenerationMixin: "result": seek_outputs[idx], } ) + if return_token_timestamps: + segments[-1]["token_timestamps"] = ( + token_timestamps[last_slice:current_slice] + time_offset[prev_idx] + ) last_slice = current_slice if single_timestamp_ending: @@ -1661,7 +1675,6 @@ class WhisperGenerationMixin: if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin: # no consecutive timestamps but it has a timestamp; use the last one. last_timestamp_pos = timestamps[-1].item() - timestamp_begin - segments = [ { "start": time_offset[prev_idx], @@ -1670,6 +1683,8 @@ class WhisperGenerationMixin: "result": seek_outputs[idx], } ] + if return_token_timestamps: + segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx] segment_offset = seek_num_frames[prev_idx] return segments, segment_offset diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 5e392502c9..ee976e9ece 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -483,6 +483,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): generate_kwargs["return_timestamps"] = return_timestamps if return_timestamps == "word": generate_kwargs["return_token_timestamps"] = True + generate_kwargs["return_segments"] = True if stride is not None: if isinstance(stride, tuple): @@ -499,8 +500,16 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): attention_mask=attention_mask, **generate_kwargs, ) + # whisper longform generation stores timestamps in "segments" if return_timestamps == "word" and self.type == "seq2seq_whisper": - out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]} + if "segments" not in tokens: + out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]} + else: + token_timestamps = [ + torch.cat([segment["token_timestamps"] for segment in segment_list]) + for segment_list in tokens["segments"] + ] + out = {"tokens": tokens["sequences"], "token_timestamps": token_timestamps} else: out = {"tokens": tokens} if self.type == "seq2seq_whisper": diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 1f92f1523d..dc24a5bc34 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1969,6 +1969,56 @@ class WhisperModelIntegrationTests(unittest.TestCase): self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples) + @slow + def test_tiny_token_timestamp_generation_longform(self): + set_seed(0) + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") + model.to(torch_device) + model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]] + + input_speech = self._load_datasamples(5) + long_input_speech = np.concatenate(input_speech, dtype=np.float32) + inputs = processor.feature_extractor( + raw_speech=long_input_speech, + return_tensors="pt", + truncation=False, # False so the audio isn't truncated and whole audio is sent to the model + return_attention_mask=True, + padding=True, + ) + + inputs = inputs.to(torch_device) + generate_outputs = model.generate(**inputs, return_segments=True, return_token_timestamps=True) + + token_timestamps_shape = [ + [segment["token_timestamps"].shape for segment in segment_list] + for segment_list in generate_outputs["segments"] + ] + tokens_shape = [ + [segment["tokens"].shape for segment in segment_list] for segment_list in generate_outputs["segments"] + ] + self.assertListEqual(tokens_shape, token_timestamps_shape) + + # fmt: off + EXPECTED_OUTPUT = [ + torch.tensor([0.0000, 0.4200, 0.8200, 0.9400, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0400, 2.3400, 2.5200, 2.6600, 3.2000, 3.4400, 3.5600, 3.6800, 3.8200, 4.1000, 4.3000, 4.5800, 4.9400, 5.4000, 6.3600]), + torch.tensor([ 6.5400, 6.5400, 6.7400, 6.9600, 7.2600, 7.3400, 7.5800, 7.5800, 7.6400, 7.8400, 8.1000, 8.5000, 9.0000, 9.4800, 9.7200, 10.2600, 11.1000]), + torch.tensor([11.2200, 11.2200, 11.4200, 11.6600, 12.0800, 12.4400, 12.5800, 12.8400, 13.1800, 13.6800, 14.0000, 14.2200, 14.6200, 14.9800, 15.2200, 15.6000, 15.9400, 16.2000, 16.5600, 16.8400, 16.9800]), + torch.tensor([16.9800, 16.9800, 17.3200, 18.1600, 18.6400, 18.8600, 19.2800, 19.5600, 19.8800, 20.1800, 20.3800, 20.7200, 21.1600, 21.5400, 21.9000, 22.2000, 22.4200, 22.8600, 23.7000]), + torch.tensor([23.7000, 23.7000, 23.9400, 24.1800, 24.3800, 24.8400, 25.2800, 25.6600, 25.9200, 26.2600, 26.4000, 26.5800, 26.7600, 27.1400, 27.3800, 28.0400, 28.3800, 28.8200, 29.3400, 29.5200]), + torch.tensor([29.4400, 29.4400, 29.7000, 30.0800, 30.3800, 30.5400, 30.8200, 31.0600, 31.6600, 31.9200, 32.3000, 32.4800, 32.6200, 33.6800]), + torch.tensor([33.8000, 33.8000, 33.9800, 33.9800, 34.1800, 34.4400, 34.6200, 35.0000, 35.2200, 35.3200, 35.5600, 35.9200, 36.3800, 36.6200, 36.6600, 36.9600, 37.3400, 37.9800, 38.5800, 38.7200, 38.9800, 39.4400, 39.5800, 39.8000, 40.1200, 40.2600]), + torch.tensor([40.5200, 40.5200, 40.6200, 41.1000, 41.5400, 41.9200, 42.1000, 42.3200, 42.3200, 43.0600, 44.6000]), + torch.tensor([44.7000, 44.7000, 44.8600, 44.9400, 45.1400, 45.1400, 45.2800, 45.6200, 45.9000, 46.2600, 47.1600, 47.4800, 47.7400, 48.1000, 48.2800, 48.4000, 48.6200, 48.8400, 49.0400, 49.2800, 49.4800, 49.6600, 49.9400, 50.5400]), + torch.tensor([50.5400, 50.5400, 50.6600, 50.8800, 51.2400, 51.7200, 52.8400]), + torch.tensor([52.9600, 52.9600, 53.0400, 53.2600, 53.4200, 53.5800, 53.9200, 54.1200, 54.7200, 54.9400, 55.2600, 55.6200, 55.9800, 56.5600, 56.8000, 56.9200, 57.3600, 57.9200, 58.1800, 58.5000, 58.6400, 58.8200]), + torch.tensor([58.6800, 58.6800, 59.1400, 59.5400, 59.9200, 60.1600, 60.3800, 60.8200, 61.6200, 62.2600, 75.2000]), + ] + # fmt: on + + for segment, exp_segment in zip(generate_outputs["segments"][0], EXPECTED_OUTPUT): + self.assertTrue(torch.allclose(segment["token_timestamps"], exp_segment)) + @slow def test_tiny_specaugment_librispeech(self): torch_device = "cpu" diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 42cb7e50c2..d2af7e4468 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -361,6 +361,70 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): ) # fmt: on + @slow + @require_torch + def test_return_timestamps_in_preprocess_longform(self): + pipe = pipeline( + task="automatic-speech-recognition", + model="openai/whisper-tiny.en", + ) + data = load_dataset("librispeech_asr", "clean", split="test", streaming=True) + samples = [next(iter(data)) for _ in range(8)] + audio = np.concatenate([sample["audio"]["array"] for sample in samples]) + + res = pipe(audio) + expected_output = { + "text": " Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst " + "the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst " + "the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst " + "the tents. Concord returned to its place amidst the tents." + } + self.assertEqual(res, expected_output) + res = pipe(audio, return_timestamps=True) + self.assertEqual( + res, + { + "text": " Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents.", + "chunks": [ + {"timestamp": (0.0, 3.22), "text": " Concord returned to its place amidst the tents."}, + {"timestamp": (3.22, 6.74), "text": " Concord returned to its place amidst the tents."}, + {"timestamp": (6.74, 10.26), "text": " Concord returned to its place amidst the tents."}, + {"timestamp": (10.26, 13.78), "text": " Concord returned to its place amidst the tents."}, + {"timestamp": (13.78, 17.3), "text": " Concord returned to its place amidst the tents."}, + {"timestamp": (17.3, 20.82), "text": " Concord returned to its place amidst the tents."}, + {"timestamp": (20.82, 24.34), "text": " Concord returned to its place amidst the tents."}, + {"timestamp": (24.34, 27.86), "text": " Concord returned to its place amidst the tents."}, + ], + }, + ) + pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]] + res = pipe(audio, return_timestamps="word") + + # fmt: off + self.assertEqual( + res["chunks"][:15], + [ + {"text": " Concord", "timestamp": (0.5, 0.94)}, + {"text": " returned", "timestamp": (0.94, 1.52)}, + {"text": " to", "timestamp": (1.52, 1.78)}, + {"text": " its", "timestamp": (1.78, 1.98)}, + {"text": " place", "timestamp": (1.98, 2.16)}, + {"text": " amidst", "timestamp": (2.16, 2.5)}, + {"text": " the", "timestamp": (2.5, 2.9)}, + {"text": " tents.", "timestamp": (2.9, 4.2)}, + {"text": " Concord", "timestamp": (4.2, 4.5)}, + {"text": " returned", "timestamp": (4.5, 5.0)}, + {"text": " to", "timestamp": (5.0, 5.28)}, + {"text": " its", "timestamp": (5.28, 5.48)}, + {"text": " place", "timestamp": (5.48, 5.7)}, + {"text": " amidst", "timestamp": (5.7, 6.02)}, + {"text": " the", "timestamp": (6.02, 6.4)} + + + ], + ) + # fmt: on + @require_torch def test_return_timestamps_in_init(self): # segment-level timestamps are accepted