Add tokenizers class mismatch detection between cls and checkpoint (#12619)

* Detect mismatch by analyzing config

* Fix comment

* Fix import

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>

* Revise based on reviews

* remove kwargs

* Fix exception

* Fix handling exception again

* Disable mismatch test in PreTrainedTokenizerFast

Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
This commit is contained in:
Tomohiro Endo
2021-07-17 22:52:21 +09:00
committed by GitHub
parent b4b562d834
commit 08d609bfb8
5 changed files with 110 additions and 1 deletions

View File

@@ -22,6 +22,7 @@ from transformers import AutoTokenizer
from transformers.models.bert_japanese.tokenization_bert_japanese import (
VOCAB_FILES_NAMES,
BertJapaneseTokenizer,
BertTokenizer,
CharacterTokenizer,
MecabTokenizer,
WordpieceTokenizer,
@@ -278,3 +279,23 @@ class AutoTokenizerCustomTest(unittest.TestCase):
EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
tokenizer = AutoTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
self.assertIsInstance(tokenizer, BertJapaneseTokenizer)
class BertTokenizerMismatchTest(unittest.TestCase):
def test_tokenizer_mismatch_warning(self):
EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
with self.assertLogs("transformers", level="WARNING") as cm:
BertTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
self.assertTrue(
cm.records[0].message.startswith(
"The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
)
)
EXAMPLE_BERT_ID = "bert-base-cased"
with self.assertLogs("transformers", level="WARNING") as cm:
BertJapaneseTokenizer.from_pretrained(EXAMPLE_BERT_ID)
self.assertTrue(
cm.records[0].message.startswith(
"The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
)
)