From 7adce8b53241bd0f537c5135541fbc166dbe9be8 Mon Sep 17 00:00:00 2001 From: Connor Henderson Date: Wed, 31 May 2023 10:52:35 -0400 Subject: [PATCH] fix: Replace `add_prefix_space` in `get_prompt_ids` with manual space for FastTokenizer compatibility (#23796) * add ' ' replacement for add_prefix_space * add fast tokenizer test --- .../models/whisper/tokenization_whisper.py | 2 +- .../models/whisper/tokenization_whisper_fast.py | 2 +- tests/models/whisper/test_tokenization_whisper.py | 10 ++++++++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 4c7c9c89fd..428254a26a 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -721,7 +721,7 @@ class WhisperTokenizer(PreTrainedTokenizer): def get_prompt_ids(self, text: str, return_tensors="np"): """Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`].""" - batch_encoding = self("<|startofprev|>", text.strip(), add_prefix_space=True, add_special_tokens=False) + batch_encoding = self("<|startofprev|>", " " + text.strip(), add_special_tokens=False) # Check for special tokens prompt_text_ids = batch_encoding["input_ids"][1:] diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index be4ad842a7..a31fe00056 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -494,7 +494,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.get_prompt_ids def get_prompt_ids(self, text: str, return_tensors="np"): """Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`].""" - batch_encoding = self("<|startofprev|>", text.strip(), add_prefix_space=True, add_special_tokens=False) + batch_encoding = self("<|startofprev|>", " " + text.strip(), add_special_tokens=False) # Check for special tokens prompt_text_ids = batch_encoding["input_ids"][1:] diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py index 5022d29b73..09c98db317 100644 --- a/tests/models/whisper/test_tokenization_whisper.py +++ b/tests/models/whisper/test_tokenization_whisper.py @@ -213,6 +213,16 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase): rust_tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens ) + def test_fast_tokenizer_get_prompt_ids(self): + tokenizer = self.get_tokenizer() + rust_tokenizer = self.get_rust_tokenizer() + + prompt = "This is test prompt text." + tokenizer_prompt_ids = tokenizer.get_prompt_ids(prompt) + fast_tokenizer_prompt_ids = rust_tokenizer.get_prompt_ids(prompt) + + self.assertListEqual(tokenizer_prompt_ids.tolist(), fast_tokenizer_prompt_ids.tolist()) + class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase): checkpoint_name = "openai/whisper-small.en"