From 74fb524e204d36b11ea9c34b377c91491a6704e0 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Mon, 5 Dec 2022 18:45:22 +0000 Subject: [PATCH] [Whisper] Fix decoder ids methods (#20599) * [Whisper] Fix decoder ids methods * enum property --- .../models/whisper/tokenization_whisper.py | 5 +++-- .../models/whisper/test_processor_whisper.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 1f4590ba64..e10fb7c836 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -583,5 +583,6 @@ class WhisperTokenizer(PreTrainedTokenizer): return input_ids def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True): - self.set_prefix_tokens(task=task, language=language, predict_timestamps=no_timestamps) - return self.prefix_tokens + self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps) + forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(self.prefix_tokens)] + return forced_decoder_ids diff --git a/tests/models/whisper/test_processor_whisper.py b/tests/models/whisper/test_processor_whisper.py index bcdf1fb9f0..e941db7e35 100644 --- a/tests/models/whisper/test_processor_whisper.py +++ b/tests/models/whisper/test_processor_whisper.py @@ -26,6 +26,11 @@ if is_speech_available(): from transformers import WhisperFeatureExtractor, WhisperProcessor +START_OF_TRANSCRIPT = 50257 +TRANSCRIBE = 50358 +NOTIMESTAMPS = 50362 + + @require_torch @require_torchaudio @require_sentencepiece @@ -128,3 +133,17 @@ class WhisperProcessorTest(unittest.TestCase): feature_extractor.model_input_names, msg="`processor` and `feature_extractor` model input names do not match", ) + + def test_get_decoder_prompt_ids(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = WhisperProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", no_timestamps=True) + + self.assertIsInstance(forced_decoder_ids, list) + for ids in forced_decoder_ids: + self.assertIsInstance(ids, (list, tuple)) + + expected_ids = [START_OF_TRANSCRIPT, TRANSCRIBE, NOTIMESTAMPS] + self.assertListEqual([ids[-1] for ids in forced_decoder_ids], expected_ids)