From 77382e918d717a91472ba634b9163aacaeaded38 Mon Sep 17 00:00:00 2001
From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Date: Wed, 7 Dec 2022 16:44:13 +0000
Subject: [PATCH] [Whisper] Fix forced decoder ids (#20652)

* [Whisper] Fix forced decoder ids

* fix test
---
 src/transformers/models/whisper/tokenization_whisper.py | 7 ++++++-
 tests/models/whisper/test_processor_whisper.py          | 3 +--
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py
index e10fb7c836..26c642c134 100644
--- a/src/transformers/models/whisper/tokenization_whisper.py
+++ b/src/transformers/models/whisper/tokenization_whisper.py
@@ -584,5 +584,10 @@ class WhisperTokenizer(PreTrainedTokenizer):
 
     def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
         self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps)
-        forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(self.prefix_tokens)]
+        # prefix tokens are of the form: <|startoftranscript|> <|lang_id|> <|task|> <|notimestamps|>
+        # we don't want to force the bos token at position 1, as this is the starting token
+        # when we generate, so we slice the prefix tokens to: <|lang_id|> <|task|> <|notimestamps|>
+        # to get the forced tokens
+        forced_tokens = self.prefix_tokens[1:]
+        forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_tokens)]
         return forced_decoder_ids
diff --git a/tests/models/whisper/test_processor_whisper.py b/tests/models/whisper/test_processor_whisper.py
index e941db7e35..b844d433ed 100644
--- a/tests/models/whisper/test_processor_whisper.py
+++ b/tests/models/whisper/test_processor_whisper.py
@@ -26,7 +26,6 @@ if is_speech_available():
     from transformers import WhisperFeatureExtractor, WhisperProcessor
 
 
-START_OF_TRANSCRIPT = 50257
 TRANSCRIBE = 50358
 NOTIMESTAMPS = 50362
 
@@ -145,5 +144,5 @@ class WhisperProcessorTest(unittest.TestCase):
         for ids in forced_decoder_ids:
             self.assertIsInstance(ids, (list, tuple))
 
-        expected_ids = [START_OF_TRANSCRIPT, TRANSCRIBE, NOTIMESTAMPS]
+        expected_ids = [TRANSCRIBE, NOTIMESTAMPS]
         self.assertListEqual([ids[-1] for ids in forced_decoder_ids], expected_ids)