From 0dff407d71401f22d80e50c9513d0d3bac8b1cdb Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Fri, 27 Jan 2023 16:35:16 +0100 Subject: [PATCH] [Whisper] another patch (#21324) * another patch * fix timestamp test modeling * let it be negative when the token is None --- src/transformers/generation/tf_logits_process.py | 3 ++- src/transformers/pipelines/automatic_speech_recognition.py | 2 +- tests/models/whisper/test_modeling_whisper.py | 6 ++---- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/tf_logits_process.py b/src/transformers/generation/tf_logits_process.py index d446f1f6f1..19a14c63b2 100644 --- a/src/transformers/generation/tf_logits_process.py +++ b/src/transformers/generation/tf_logits_process.py @@ -557,7 +557,8 @@ class TFForceTokensLogitsProcessor(TFLogitsProcessor): # Indexes without forced tokens will have an negative value. force_token_array = np.ones((max(force_token_map.keys()) + 1), dtype=np.int32) * -1 for index, token in force_token_map.items(): - force_token_array[index] = token + if token is not None: + force_token_array[index] = token self.force_token_array = tf.convert_to_tensor(force_token_array, dtype=tf.int32) def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index a409ff21c3..8c552cbdc3 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -101,7 +101,7 @@ def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source chunk_len, stride_left, stride_right = stride sequence = sequence.squeeze(0) # get rid of the `forced_decoder_idx` that are use to parametrize the generation - begin_idx = np.where(sequence == timestamp_begin)[0].item() if timestamp_begin in sequence else 0 + begin_idx = np.where(sequence == timestamp_begin)[0][0] if timestamp_begin in sequence else 0 sequence = sequence[begin_idx:] timestamp_tokens = sequence >= timestamp_begin diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index e9a6f0705f..0944a3a646 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -46,7 +46,6 @@ if is_torch_available(): WhisperProcessor, set_seed, ) - from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder @@ -1077,9 +1076,8 @@ class WhisperModelIntegrationTests(unittest.TestCase): input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to( torch_device ) - model.config.forced_decoder_ids = [(1, 50259), (2, 50359), (3, 50364)] - timestamp_processor = [WhisperTimeStampLogitsProcessor(len(model.config.forced_decoder_ids))] - generated_ids = model.generate(input_features, max_length=448, logits_processor=timestamp_processor).to("cpu") + + generated_ids = model.generate(input_features, max_length=448, return_timestamps=True).to("cpu") # fmt: off EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50692, 50692, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50926, 50926, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 51208, 51208, 949, 505, 11, 14138, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51552, 51552, 634, 575, 12525, 22618, 1968, 6144, 35617, 7354, 1292, 6, 589, 307, 534, 10281, 934, 439, 11, 293, 51836, 51836, 50257])