From 5da3db3fd5c070107df717a13382ccf1fe1efbe4 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Fri, 22 Dec 2023 12:43:11 +0000 Subject: [PATCH] [Whisper] Fix word-level timestamps with bs>1 or num_beams>1 (#28114) * fix frames * use smaller chunk length * correct beam search + tentative stride * fix whisper word timestamp in batch * add test batch generation with return token timestamps * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * clean a test * make style + correct typo * write clearer comments * explain test in comment --------- Co-authored-by: sanchit-gandhi Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- .../models/whisper/modeling_whisper.py | 71 ++++++++++++++++--- .../pipelines/automatic_speech_recognition.py | 5 +- tests/models/whisper/test_modeling_whisper.py | 29 ++++++++ ..._pipelines_automatic_speech_recognition.py | 44 ++++++++++++ 4 files changed, 138 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index fb2bce476e..07c0e0afc1 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -2224,6 +2224,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): if return_token_timestamps: kwargs["output_attentions"] = True return_dict_in_generate = True + kwargs["output_scores"] = True if getattr(generation_config, "task", None) == "translate": logger.warning("Token-level timestamps may not be reliable for task 'translate'.") @@ -2555,22 +2556,72 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): # of shape (batch size, num selected, output length, input length). weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads]) weights = weights.permute([1, 0, 2, 3]) - if num_frames is not None: - weights = weights[..., : num_frames // 2] - # Normalize and smoothen the weights. - std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False) - weights = (weights - mean) / std - weights = _median_filter(weights, self.config.median_filter_width) + if "beam_indices" in generate_outputs: + # If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths + # since the beam search strategy chooses the most probable sequences at the end of the search. + # In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length + weight_length = (generate_outputs.beam_indices != -1).sum(-1).max() + weights = weights[:, :, :weight_length] - # Average the different cross-attention heads. - matrix = weights.mean(dim=1) + # If beam index is still -1, it means that the associated token id is EOS + # We need to replace the index with 0 since index_select gives an error if any of the indexes is -1. + beam_indices = generate_outputs.beam_indices[:, :weight_length] + beam_indices = beam_indices.masked_fill(beam_indices == -1, 0) + + # Select the cross attention from the right beam for each output sequences + weights = torch.stack( + [ + torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i]) + for i in range(beam_indices.shape[1]) + ], + dim=2, + ) timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32) + batch_size = timestamps.shape[0] + + if num_frames is not None: + # two cases: + # 1. num_frames is the same for each sample -> compute the DTW matrix for each sample in parallel + # 2. num_frames is different, compute the DTW matrix for each sample sequentially + + # we're using np.unique because num_frames can be int/list/tuple + if len(np.unique(num_frames)) == 1: + # if num_frames is the same, no need to recompute matrix, std and mean for each element of the batch + num_frames = num_frames if isinstance(num_frames, int) else num_frames[0] + + weights = weights[..., : num_frames // 2] + else: + # num_frames is of shape (batch_size,) whereas batch_size is truely batch_size*num_return_sequences + repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames) + num_frames = np.repeat(num_frames, repeat_time) + + if num_frames is None or isinstance(num_frames, int): + # Normalize and smoothen the weights. + std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False) + weights = (weights - mean) / std + weights = _median_filter(weights, self.config.median_filter_width) + + # Average the different cross-attention heads. + weights = weights.mean(dim=1) # Perform dynamic time warping on each element of the batch. - for batch_idx in range(timestamps.shape[0]): - text_indices, time_indices = _dynamic_time_warping(-matrix[batch_idx].double().cpu().numpy()) + for batch_idx in range(batch_size): + if num_frames is not None and isinstance(num_frames, (tuple, list)): + matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2] + + # Normalize and smoothen the weights. + std, mean = torch.std_mean(matrix, dim=-2, keepdim=True, unbiased=False) + matrix = (matrix - mean) / std + matrix = _median_filter(matrix, self.config.median_filter_width) + + # Average the different cross-attention heads. + matrix = matrix.mean(dim=0) + else: + matrix = weights[batch_idx] + + text_indices, time_indices = _dynamic_time_warping(-matrix.double().cpu().numpy()) jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) jump_times = time_indices[jumps] * time_precision timestamps[batch_idx, 1:] = torch.tensor(jump_times) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 8fd1701d3c..32e61db42a 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -559,7 +559,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): generate_kwargs["return_token_timestamps"] = True if stride is not None: - generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length + if isinstance(stride, tuple): + generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length + else: + generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride] if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames: generate_kwargs["input_features"] = inputs diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 9de3b8ff2c..6ca7bfcb5d 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1850,6 +1850,35 @@ class WhisperModelIntegrationTests(unittest.TestCase): self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT)) + @slow + def test_tiny_token_timestamp_batch_generation(self): + set_seed(0) + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") + model.to(torch_device) + model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]] + + num_samples = 4 + num_return_sequences = 2 + + input_speech = self._load_datasamples(num_samples) + input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to( + torch_device + ) + + generate_outputs = model.generate( + input_features, + max_length=448, + return_timestamps=True, + return_token_timestamps=True, + num_beams=3, + num_return_sequences=num_return_sequences, + ) + + self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape) + + self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples) + @slow def test_tiny_specaugment_librispeech(self): torch_device = "cpu" diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 2ccaf71255..3da55ab9da 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -674,6 +674,50 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): }, ) + @slow + @require_torch + def test_whisper_word_timestamps_batched(self): + pipe = pipeline( + task="automatic-speech-recognition", + model="openai/whisper-tiny", + chunk_length_s=3, + return_timestamps="word", + ) + data = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + sample = data[0]["audio"] + + # not the same output as test_simple_whisper_asr because of chunking + EXPECTED_OUTPUT = { + "text": " Mr. Quilder is the apostle of the middle classes and we are glad to welcome his gospel.", + "chunks": [ + {"text": " Mr.", "timestamp": (0.48, 0.96)}, + {"text": " Quilder", "timestamp": (0.96, 1.24)}, + {"text": " is", "timestamp": (1.24, 1.5)}, + {"text": " the", "timestamp": (1.5, 1.72)}, + {"text": " apostle", "timestamp": (1.72, 1.98)}, + {"text": " of", "timestamp": (1.98, 2.32)}, + {"text": " the", "timestamp": (2.32, 2.5)}, + {"text": " middle", "timestamp": (2.5, 2.68)}, + {"text": " classes", "timestamp": (2.68, 3.2)}, + {"text": " and", "timestamp": (3.2, 3.56)}, + {"text": " we", "timestamp": (3.56, 3.68)}, + {"text": " are", "timestamp": (3.68, 3.8)}, + {"text": " glad", "timestamp": (3.8, 4.1)}, + {"text": " to", "timestamp": (4.1, 4.34)}, + {"text": " welcome", "timestamp": (4.3, 4.6)}, + {"text": " his", "timestamp": (4.6, 4.94)}, + {"text": " gospel.", "timestamp": (4.94, 5.82)}, + ], + } + + # batch size 1: copy the audio sample since pipeline consumes it + output = pipe(sample.copy(), batch_size=1) + self.assertDictEqual(output, EXPECTED_OUTPUT) + + # batch size 2: input audio is chunked into smaller pieces so it's testing batching + output = pipe(sample, batch_size=2) + self.assertDictEqual(output, EXPECTED_OUTPUT) + @require_torch @slow def test_torch_speech_encoder_decoder(self):