From 0948c827de87056696032f1b84bad309259b0eb7 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Wed, 22 May 2024 17:25:47 +0100 Subject: [PATCH] [Whisper] Strip prompt before finding common subsequence (#27836) --- .../models/whisper/tokenization_whisper.py | 4 +++ ..._pipelines_automatic_speech_recognition.py | 36 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 5f7f0ae53d..303822de65 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -897,11 +897,15 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, right_stride_start = None all_special_ids = set(tokenizer.all_special_ids) + prompt_token_id = tokenizer.convert_tokens_to_ids("<|startofprev|>") + decoder_start_token_id = tokenizer.convert_tokens_to_ids("<|startoftranscript|>") # - iterate over all outputs for chunk_id, output in enumerate(model_outputs): # We can drop everything to Python list, it's going to make # our lives easier token_ids = output["tokens"][0].tolist() + # (possibly) remove the prompt from the token ids + token_ids = tokenizer._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id) if return_timestamps == "word": token_timestamps = output["token_timestamps"][0].tolist() diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index bc619769e1..5ab18e81d5 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -1343,6 +1343,42 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): self.assertEqual(output, [{"text": ANY(str)}]) self.assertEqual(output[0]["text"][:6], "