Add tokenizers class mismatch detection between cls and checkpoint (#12619)
* Detect mismatch by analyzing config
* Fix comment
* Fix import
* Update src/transformers/tokenization_utils_base.py
  Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
* Revise based on reviews
* remove kwargs
* Fix exception
* Fix handling exception again
* Disable mismatch test in PreTrainedTokenizerFast

Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
This commit is contained in:
@@ -29,7 +29,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
||||
from huggingface_hub import HfApi
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import (
|
||||
AlbertTokenizer,
|
||||
AlbertTokenizerFast,
|
||||
BertTokenizer,
|
||||
BertTokenizerFast,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerBase,
|
||||
PreTrainedTokenizerFast,
|
||||
@@ -3292,6 +3295,41 @@ class TokenizerTesterMixin:
|
||||
expected_result = tokenizer.backend_tokenizer.normalizer.normalize_str(expected_result)
|
||||
self.assertEqual(expected_result, decoded_input)
|
||||
|
||||
def test_tokenizer_mismatch_warning(self):
    """Check that loading a checkpoint through a mismatching tokenizer class logs a warning.

    For every entry of ``self.tokenizers_list`` the checkpoint is loaded twice, once through a
    slow tokenizer class and once through a fast one, each chosen to differ from the class under
    test (Albert when the class under test is Bert, Bert otherwise). Each load must emit, on the
    ``transformers`` logger at WARNING level or above, a message starting with the mismatch notice.
    """
    expected_prefix = (
        "The tokenizer class you load from this checkpoint is not the same type as the class "
        "this function is called from."
    )

    for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
        with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
            # One deliberately wrong class per flavour: slow first, then fast.
            mismatching_classes = (
                AlbertTokenizer if self.tokenizer_class == BertTokenizer else BertTokenizer,
                AlbertTokenizerFast if self.rust_tokenizer_class == BertTokenizerFast else BertTokenizerFast,
            )

            for mismatching_class in mismatching_classes:
                # Use a fresh log capture for every load. With a single shared capture the
                # second check would only re-read the record produced by the first load and
                # could never fail on its own.
                with self.assertLogs("transformers", level="WARNING") as cm:
                    try:
                        mismatching_class.from_pretrained(pretrained_name)
                    except (TypeError, AttributeError):
                        # Some tokenizers cannot be loaded into the target tokenizer at all and errors are
                        # returned, here we just check that the warning has been logged before the error is
                        # raised
                        pass

                # Asserting after the capture block (instead of indexing records in a `finally`)
                # lets an unexpected exception or a missing log surface as itself rather than
                # as an IndexError on an empty record list.
                self.assertTrue(
                    any(record.getMessage().startswith(expected_prefix) for record in cm.records),
                    f"{mismatching_class.__name__}.from_pretrained({pretrained_name!r}) "
                    "did not log the tokenizer class mismatch warning",
                )
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class TokenizerPushToHubTester(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user