From 35c155052d49bc6429922a6cf946fafa617a87b5 Mon Sep 17 00:00:00 2001 From: CL-ModelCloud Date: Thu, 13 Feb 2025 19:00:33 +0800 Subject: [PATCH] Fix PretrainedTokenizerFast check => Fix PretrainedTokenizerFast Save (#35835) * Fix the bug in tokenizer.save_pretrained when saving tokenizer_class to tokenizer_config.json * Update tokenization_utils_base.py * Update tokenization_utils_base.py * Update tokenization_utils_base.py * add tokenizer class type test * code review * code opt * fix bug * Update test_tokenization_fast.py * ruff check * make style * code opt * Update test_tokenization_fast.py --------- Co-authored-by: Qubitium-ModelCloud Co-authored-by: LRL-ModelCloud <165116337+LRL-ModelCloud@users.noreply.github.com> --- src/transformers/tokenization_utils_base.py | 4 +-- tests/tokenization/test_tokenization_fast.py | 37 +++++++++++++++++++- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 7ad36ab017..fc31e72984 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -2472,8 +2472,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): # Add tokenizer class to the tokenizer config to be able to reload it with from_pretrained tokenizer_class = self.__class__.__name__ - # Remove the Fast at the end unless we have a special `PreTrainedTokenizerFast` - if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast": + # Remove the Fast at the end if we can save the slow tokenizer + if tokenizer_class.endswith("Fast") and getattr(self, "can_save_slow_tokenizer", False): tokenizer_class = tokenizer_class[:-4] tokenizer_config["tokenizer_class"] = tokenizer_class if getattr(self, "_auto_map", None) is not None: diff --git a/tests/tokenization/test_tokenization_fast.py b/tests/tokenization/test_tokenization_fast.py index d5c6444de4..4bd9b046d4 100644 --- a/tests/tokenization/test_tokenization_fast.py +++ b/tests/tokenization/test_tokenization_fast.py @@ -20,7 +20,7 @@ import shutil import tempfile import unittest -from transformers import AutoTokenizer, PreTrainedTokenizerFast +from transformers import AutoTokenizer, LlamaTokenizerFast, PreTrainedTokenizerFast from transformers.testing_utils import require_tokenizers from ..test_tokenization_common import TokenizerTesterMixin @@ -170,6 +170,41 @@ class PreTrainedTokenizationFastTest(TokenizerTesterMixin, unittest.TestCase): # thus tok(sentences, truncation = True) does nothing and does not warn either self.assertEqual(tok(sentences, truncation = True, max_length = 8), {'input_ids': [[8774, 6, 3, 63, 31, 1748, 55, 1],[ 571, 33, 25, 3, 2, 3, 58, 1]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1]]}) # fmt: skip + def test_class_after_save_and_reload(self): + # Model contains a `LlamaTokenizerFast` tokenizer with no slow fallback + model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + + with tempfile.TemporaryDirectory() as temp_dir: + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) + self.assertTrue( + isinstance(tokenizer, LlamaTokenizerFast), + f"Expected tokenizer(use_fast=True) type: `LlamaTokenizerFast`, actual=`{type(tokenizer)}`", + ) + + # Fast tokenizer will ignore `use_fast=False` + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) + self.assertTrue( + isinstance(tokenizer, LlamaTokenizerFast), + f"Expected tokenizer type(use_fast=False): `LlamaTokenizerFast`, actual=`{type(tokenizer)}`", + ) + + # Save tokenizer + tokenizer.save_pretrained(temp_dir) + + tokenizer = AutoTokenizer.from_pretrained(temp_dir, use_fast=False) + # Verify post save and reload the fast tokenizer class did not change + self.assertTrue( + isinstance(tokenizer, LlamaTokenizerFast), + f"Expected tokenizer type: `LlamaTokenizerFast`, actual=`{type(tokenizer)}`", + ) + + tokenizer = AutoTokenizer.from_pretrained(temp_dir, use_fast=True) + # Verify post save and reload the fast tokenizer class did not change + self.assertTrue( + isinstance(tokenizer, LlamaTokenizerFast), + f"Expected tokenizer type: `LlamaTokenizerFast`, actual=`{type(tokenizer)}`", + ) + @require_tokenizers class TokenizerVersioningTest(unittest.TestCase):