Fix PretrainedTokenizerFast check => Fix PretrainedTokenizerFast Save (#35835)
* Fix the bug in tokenizer.save_pretrained when saving tokenizer_class to tokenizer_config.json
* Update tokenization_utils_base.py
* Update tokenization_utils_base.py
* Update tokenization_utils_base.py
* add tokenizer class type test
* code review
* code opt
* fix bug
* Update test_tokenization_fast.py
* ruff check
* make style
* code opt
* Update test_tokenization_fast.py

---------

Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>
Co-authored-by: LRL-ModelCloud <165116337+LRL-ModelCloud@users.noreply.github.com>
This commit is contained in:
@@ -2472,8 +2472,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
|
|
||||||
# Add tokenizer class to the tokenizer config to be able to reload it with from_pretrained
|
# Add tokenizer class to the tokenizer config to be able to reload it with from_pretrained
|
||||||
tokenizer_class = self.__class__.__name__
|
tokenizer_class = self.__class__.__name__
|
||||||
# Remove the Fast at the end unless we have a special `PreTrainedTokenizerFast`
|
# Remove the Fast at the end if we can save the slow tokenizer
|
||||||
if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast":
|
if tokenizer_class.endswith("Fast") and getattr(self, "can_save_slow_tokenizer", False):
|
||||||
tokenizer_class = tokenizer_class[:-4]
|
tokenizer_class = tokenizer_class[:-4]
|
||||||
tokenizer_config["tokenizer_class"] = tokenizer_class
|
tokenizer_config["tokenizer_class"] = tokenizer_class
|
||||||
if getattr(self, "_auto_map", None) is not None:
|
if getattr(self, "_auto_map", None) is not None:
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import shutil
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import AutoTokenizer, PreTrainedTokenizerFast
|
from transformers import AutoTokenizer, LlamaTokenizerFast, PreTrainedTokenizerFast
|
||||||
from transformers.testing_utils import require_tokenizers
|
from transformers.testing_utils import require_tokenizers
|
||||||
|
|
||||||
from ..test_tokenization_common import TokenizerTesterMixin
|
from ..test_tokenization_common import TokenizerTesterMixin
|
||||||
@@ -170,6 +170,41 @@ class PreTrainedTokenizationFastTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
# thus tok(sentences, truncation = True) does nothing and does not warn either
|
# thus tok(sentences, truncation = True) does nothing and does not warn either
|
||||||
self.assertEqual(tok(sentences, truncation = True, max_length = 8), {'input_ids': [[8774, 6, 3, 63, 31, 1748, 55, 1],[ 571, 33, 25, 3, 2, 3, 58, 1]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1]]}) # fmt: skip
|
self.assertEqual(tok(sentences, truncation = True, max_length = 8), {'input_ids': [[8774, 6, 3, 63, 31, 1748, 55, 1],[ 571, 33, 25, 3, 2, 3, 58, 1]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1]]}) # fmt: skip
|
||||||
|
|
||||||
|
def test_class_after_save_and_reload(self):
    """Check that the tokenizer class stays `LlamaTokenizerFast` across save and reload.

    The tokenizer is loaded from the hub with `use_fast=True` and `use_fast=False`, saved to a
    temporary directory, and reloaded from that directory with both `use_fast` values. Every load
    is asserted to produce a `LlamaTokenizerFast` instance.
    """
    # Model contains a `LlamaTokenizerFast` tokenizer with no slow fallback
    model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

    with tempfile.TemporaryDirectory() as temp_dir:
        hub_fast = AutoTokenizer.from_pretrained(model_id, use_fast=True)
        hub_fast_ok = isinstance(hub_fast, LlamaTokenizerFast)
        self.assertTrue(
            hub_fast_ok,
            f"Expected tokenizer(use_fast=True) type: `LlamaTokenizerFast`, actual=`{type(hub_fast)}`",
        )

        # Fast tokenizer will ignore `use_fast=False`
        hub_slow_request = AutoTokenizer.from_pretrained(model_id, use_fast=False)
        hub_slow_request_ok = isinstance(hub_slow_request, LlamaTokenizerFast)
        self.assertTrue(
            hub_slow_request_ok,
            f"Expected tokenizer type(use_fast=False): `LlamaTokenizerFast`, actual=`{type(hub_slow_request)}`",
        )

        # Persist the tokenizer that was loaded last (the `use_fast=False` request)
        hub_slow_request.save_pretrained(temp_dir)

        # After save and reload with `use_fast=False`, the fast tokenizer class must be unchanged
        reloaded_slow_request = AutoTokenizer.from_pretrained(temp_dir, use_fast=False)
        reloaded_slow_request_ok = isinstance(reloaded_slow_request, LlamaTokenizerFast)
        self.assertTrue(
            reloaded_slow_request_ok,
            f"Expected tokenizer type: `LlamaTokenizerFast`, actual=`{type(reloaded_slow_request)}`",
        )

        # After save and reload with `use_fast=True`, the fast tokenizer class must be unchanged
        reloaded_fast = AutoTokenizer.from_pretrained(temp_dir, use_fast=True)
        reloaded_fast_ok = isinstance(reloaded_fast, LlamaTokenizerFast)
        self.assertTrue(
            reloaded_fast_ok,
            f"Expected tokenizer type: `LlamaTokenizerFast`, actual=`{type(reloaded_fast)}`",
        )
|
||||||
|
|
||||||
|
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
class TokenizerVersioningTest(unittest.TestCase):
|
class TokenizerVersioningTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user