Fix PretrainedTokenizerFast check => Fix PretrainedTokenizerFast Save (#35835)
* Fix the bug in tokenizer.save_pretrained when saving tokenizer_class to tokenizer_config.json * Update tokenization_utils_base.py * Update tokenization_utils_base.py * Update tokenization_utils_base.py * add tokenizer class type test * code review * code opt * fix bug * Update test_tokenization_fast.py * ruff check * make style * code opt * Update test_tokenization_fast.py --------- Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai> Co-authored-by: LRL-ModelCloud <165116337+LRL-ModelCloud@users.noreply.github.com>
This commit is contained in:
@@ -20,7 +20,7 @@ import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerFast
|
||||
from transformers import AutoTokenizer, LlamaTokenizerFast, PreTrainedTokenizerFast
|
||||
from transformers.testing_utils import require_tokenizers
|
||||
|
||||
from ..test_tokenization_common import TokenizerTesterMixin
|
||||
@@ -170,6 +170,41 @@ class PreTrainedTokenizationFastTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
# thus tok(sentences, truncation = True) does nothing and does not warn either
|
||||
self.assertEqual(tok(sentences, truncation = True, max_length = 8), {'input_ids': [[8774, 6, 3, 63, 31, 1748, 55, 1],[ 571, 33, 25, 3, 2, 3, 58, 1]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1]]}) # fmt: skip
|
||||
|
||||
def test_class_after_save_and_reload(self):
|
||||
# Model contains a `LlamaTokenizerFast` tokenizer with no slow fallback
|
||||
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
|
||||
self.assertTrue(
|
||||
isinstance(tokenizer, LlamaTokenizerFast),
|
||||
f"Expected tokenizer(use_fast=True) type: `LlamaTokenizerFast`, actual=`{type(tokenizer)}`",
|
||||
)
|
||||
|
||||
# Fast tokenizer will ignore `use_fast=False`
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
|
||||
self.assertTrue(
|
||||
isinstance(tokenizer, LlamaTokenizerFast),
|
||||
f"Expected tokenizer type(use_fast=False): `LlamaTokenizerFast`, actual=`{type(tokenizer)}`",
|
||||
)
|
||||
|
||||
# Save tokenizer
|
||||
tokenizer.save_pretrained(temp_dir)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(temp_dir, use_fast=False)
|
||||
# Verify post save and reload the fast tokenizer class did not change
|
||||
self.assertTrue(
|
||||
isinstance(tokenizer, LlamaTokenizerFast),
|
||||
f"Expected tokenizer type: `LlamaTokenizerFast`, actual=`{type(tokenizer)}`",
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(temp_dir, use_fast=True)
|
||||
# Verify post save and reload the fast tokenizer class did not change
|
||||
self.assertTrue(
|
||||
isinstance(tokenizer, LlamaTokenizerFast),
|
||||
f"Expected tokenizer type: `LlamaTokenizerFast`, actual=`{type(tokenizer)}`",
|
||||
)
|
||||
|
||||
|
||||
@require_tokenizers
|
||||
class TokenizerVersioningTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user