Add tokenizers class mismatch detection between cls and checkpoint (#12619)
* Detect mismatch by analyzing config
* Fix comment
* Fix import
* Update src/transformers/tokenization_utils_base.py
  Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
* Revise based on reviews
* remove kwargs
* Fix exception
* Fix handling exception again
* Disable mismatch test in PreTrainedTokenizerFast

Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
This commit is contained in:
@@ -22,6 +22,7 @@ from transformers import AutoTokenizer
|
||||
from transformers.models.bert_japanese.tokenization_bert_japanese import (
|
||||
VOCAB_FILES_NAMES,
|
||||
BertJapaneseTokenizer,
|
||||
BertTokenizer,
|
||||
CharacterTokenizer,
|
||||
MecabTokenizer,
|
||||
WordpieceTokenizer,
|
||||
@@ -278,3 +279,23 @@ class AutoTokenizerCustomTest(unittest.TestCase):
|
||||
EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
|
||||
tokenizer = AutoTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
|
||||
self.assertIsInstance(tokenizer, BertJapaneseTokenizer)
|
||||
|
||||
|
||||
class BertTokenizerMismatchTest(unittest.TestCase):
    """Check that loading a checkpoint through a different tokenizer class logs the mismatch warning."""

    def test_tokenizer_mismatch_warning(self):
        """
        Load each checkpoint with a tokenizer class other than the one it is normally used with and assert that a
        warning starting with the mismatch message is emitted on the ``transformers`` logger.
        """
        expected_prefix = "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."

        # (class used to load the checkpoint, checkpoint identifier)
        mismatched_loads = (
            (BertTokenizer, "cl-tohoku/bert-base-japanese"),
            (BertJapaneseTokenizer, "bert-base-cased"),
        )

        for tokenizer_class, checkpoint_id in mismatched_loads:
            with self.subTest(tokenizer_class=tokenizer_class.__name__, checkpoint=checkpoint_id):
                # `assertLogs` itself fails with a clear message if nothing at WARNING level is logged.
                with self.assertLogs("transformers", level="WARNING") as cm:
                    tokenizer_class.from_pretrained(checkpoint_id)

                # Look at every captured record rather than only the first one: an unrelated warning emitted
                # earlier during loading must not make this test fail, and an empty list must not raise IndexError.
                self.assertTrue(
                    any(record.getMessage().startswith(expected_prefix) for record in cm.records),
                    f"No tokenizer class mismatch warning was logged; captured: {cm.output}",
                )
|
||||
|
||||
@@ -29,7 +29,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
||||
from huggingface_hub import HfApi
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import (
|
||||
AlbertTokenizer,
|
||||
AlbertTokenizerFast,
|
||||
BertTokenizer,
|
||||
BertTokenizerFast,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerBase,
|
||||
PreTrainedTokenizerFast,
|
||||
@@ -3292,6 +3295,41 @@ class TokenizerTesterMixin:
|
||||
expected_result = tokenizer.backend_tokenizer.normalizer.normalize_str(expected_result)
|
||||
self.assertEqual(expected_result, decoded_input)
|
||||
|
||||
def test_tokenizer_mismatch_warning(self):
    """
    Check that loading a checkpoint with a tokenizer class different from the tested one logs a mismatch warning.

    For every entry of ``self.tokenizers_list`` the checkpoint is loaded once through a slow class and once through
    a fast class chosen to differ from ``self.tokenizer_class`` / ``self.rust_tokenizer_class`` (Albert when the
    tested class is Bert, Bert otherwise). Each load gets its own log capture so that the warning produced by one
    load cannot satisfy the check for the other.
    """
    expected_prefix = "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."

    # Pick, for the slow and the fast flavour, a class that is guaranteed not to be the one under test.
    mismatched_slow_class = AlbertTokenizer if self.tokenizer_class == BertTokenizer else BertTokenizer
    mismatched_fast_class = (
        AlbertTokenizerFast if self.rust_tokenizer_class == BertTokenizerFast else BertTokenizerFast
    )

    for tokenizer, pretrained_name, _ in self.tokenizers_list:
        for mismatched_class in (mismatched_slow_class, mismatched_fast_class):
            with self.subTest(
                f"{tokenizer.__class__.__name__} ({pretrained_name})", loaded_as=mismatched_class.__name__
            ):
                # A fresh capture per load: `assertLogs` fails cleanly if no WARNING is emitted at all.
                with self.assertLogs("transformers", level="WARNING") as cm:
                    try:
                        mismatched_class.from_pretrained(pretrained_name)
                    except (TypeError, AttributeError):
                        # Some tokenizers cannot be loaded into the target tokenizer at all and errors are
                        # returned, here we just check that the warning has been logged before the error is
                        # raised
                        pass

                # Checked after the capture closes and over all records, so an unexpected exception is never
                # masked by an IndexError and an unrelated earlier warning does not cause a spurious failure.
                self.assertTrue(
                    any(record.getMessage().startswith(expected_prefix) for record in cm.records),
                    f"No tokenizer class mismatch warning was logged; captured: {cm.output}",
                )
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class TokenizerPushToHubTester(unittest.TestCase):
|
||||
|
||||
@@ -44,6 +44,11 @@ class PreTrainedTokenizationFastTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_paths[0])
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
|
||||
def test_tokenizer_mismatch_warning(self):
|
||||
# We disable this test for PreTrainedTokenizerFast because it is the only tokenizer that is not linked to any
|
||||
# model
|
||||
pass
|
||||
|
||||
def test_pretrained_model_lists(self):
|
||||
# We disable this test for PreTrainedTokenizerFast because it is the only tokenizer that is not linked to any
|
||||
# model
|
||||
|
||||
Reference in New Issue
Block a user