Add tokenizers class mismatch detection between cls and checkpoint (#12619)
* Detect mismatch by analyzing config
* Fix comment
* Fix import
* Update src/transformers/tokenization_utils_base.py
  Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
* Revise based on reviews
* remove kwargs
* Fix exception
* Fix handling exception again
* Disable mismatch test in PreTrainedTokenizerFast

Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
This commit is contained in:
@@ -22,6 +22,7 @@ from transformers import AutoTokenizer
|
||||
from transformers.models.bert_japanese.tokenization_bert_japanese import (
|
||||
VOCAB_FILES_NAMES,
|
||||
BertJapaneseTokenizer,
|
||||
BertTokenizer,
|
||||
CharacterTokenizer,
|
||||
MecabTokenizer,
|
||||
WordpieceTokenizer,
|
||||
@@ -278,3 +279,23 @@ class AutoTokenizerCustomTest(unittest.TestCase):
|
||||
EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
|
||||
tokenizer = AutoTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
|
||||
self.assertIsInstance(tokenizer, BertJapaneseTokenizer)
|
||||
|
||||
|
||||
class BertTokenizerMismatchTest(unittest.TestCase):
    """Check the warning logged when a checkpoint is loaded through a different tokenizer class."""

    def test_tokenizer_mismatch_warning(self):
        """Load each checkpoint with the other tokenizer class and check the first logged warning."""
        expected_prefix = "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."

        # (class used for loading, checkpoint id), in the order they are exercised.
        mismatched_pairs = (
            (BertTokenizer, "cl-tohoku/bert-base-japanese"),
            (BertJapaneseTokenizer, "bert-base-cased"),
        )

        for tokenizer_cls, checkpoint_id in mismatched_pairs:
            with self.assertLogs("transformers", level="WARNING") as captured:
                tokenizer_cls.from_pretrained(checkpoint_id)
                # Only the first captured record is inspected.
                first_message = captured.records[0].message
                self.assertTrue(first_message.startswith(expected_prefix))
|
||||
|
||||
Reference in New Issue
Block a user