fix: Replace add_prefix_space in get_prompt_ids with manual space for FastTokenizer compatibility (#23796)
* add ' ' replacement for add_prefix_space * add fast tokenizer test
This commit is contained in:
@@ -721,7 +721,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
def get_prompt_ids(self, text: str, return_tensors="np"):
|
def get_prompt_ids(self, text: str, return_tensors="np"):
|
||||||
"""Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`]."""
|
"""Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`]."""
|
||||||
batch_encoding = self("<|startofprev|>", text.strip(), add_prefix_space=True, add_special_tokens=False)
|
batch_encoding = self("<|startofprev|>", " " + text.strip(), add_special_tokens=False)
|
||||||
|
|
||||||
# Check for special tokens
|
# Check for special tokens
|
||||||
prompt_text_ids = batch_encoding["input_ids"][1:]
|
prompt_text_ids = batch_encoding["input_ids"][1:]
|
||||||
|
|||||||
@@ -494,7 +494,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.get_prompt_ids
|
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.get_prompt_ids
|
||||||
def get_prompt_ids(self, text: str, return_tensors="np"):
|
def get_prompt_ids(self, text: str, return_tensors="np"):
|
||||||
"""Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`]."""
|
"""Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`]."""
|
||||||
batch_encoding = self("<|startofprev|>", text.strip(), add_prefix_space=True, add_special_tokens=False)
|
batch_encoding = self("<|startofprev|>", " " + text.strip(), add_special_tokens=False)
|
||||||
|
|
||||||
# Check for special tokens
|
# Check for special tokens
|
||||||
prompt_text_ids = batch_encoding["input_ids"][1:]
|
prompt_text_ids = batch_encoding["input_ids"][1:]
|
||||||
|
|||||||
@@ -213,6 +213,16 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
rust_tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens
|
rust_tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_fast_tokenizer_get_prompt_ids(self):
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
rust_tokenizer = self.get_rust_tokenizer()
|
||||||
|
|
||||||
|
prompt = "This is test prompt text."
|
||||||
|
tokenizer_prompt_ids = tokenizer.get_prompt_ids(prompt)
|
||||||
|
fast_tokenizer_prompt_ids = rust_tokenizer.get_prompt_ids(prompt)
|
||||||
|
|
||||||
|
self.assertListEqual(tokenizer_prompt_ids.tolist(), fast_tokenizer_prompt_ids.tolist())
|
||||||
|
|
||||||
|
|
||||||
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||||
checkpoint_name = "openai/whisper-small.en"
|
checkpoint_name = "openai/whisper-small.en"
|
||||||
|
|||||||
Reference in New Issue
Block a user