Add tokenizers class mismatch detection between cls and checkpoint (#12619)

* Detect mismatch by analyzing config

* Fix comment

* Fix import

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>

* Revise based on reviews

* remove kwargs

* Fix exception

* Fix handling exception again

* Disable mismatch test in PreTrainedTokenizerFast

Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
This commit is contained in:
Tomohiro Endo
2021-07-17 22:52:21 +09:00
committed by GitHub
parent b4b562d834
commit 08d609bfb8
5 changed files with 110 additions and 1 deletions

View File

@@ -22,6 +22,7 @@ from transformers import AutoTokenizer
from transformers.models.bert_japanese.tokenization_bert_japanese import (
VOCAB_FILES_NAMES,
BertJapaneseTokenizer,
BertTokenizer,
CharacterTokenizer,
MecabTokenizer,
WordpieceTokenizer,
@@ -278,3 +279,23 @@ class AutoTokenizerCustomTest(unittest.TestCase):
EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
tokenizer = AutoTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
self.assertIsInstance(tokenizer, BertJapaneseTokenizer)
class BertTokenizerMismatchTest(unittest.TestCase):
def test_tokenizer_mismatch_warning(self):
EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
with self.assertLogs("transformers", level="WARNING") as cm:
BertTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
self.assertTrue(
cm.records[0].message.startswith(
"The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
)
)
EXAMPLE_BERT_ID = "bert-base-cased"
with self.assertLogs("transformers", level="WARNING") as cm:
BertJapaneseTokenizer.from_pretrained(EXAMPLE_BERT_ID)
self.assertTrue(
cm.records[0].message.startswith(
"The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
)
)

View File

@@ -29,7 +29,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import (
AlbertTokenizer,
AlbertTokenizerFast,
BertTokenizer,
BertTokenizerFast,
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
@@ -3292,6 +3295,41 @@ class TokenizerTesterMixin:
expected_result = tokenizer.backend_tokenizer.normalizer.normalize_str(expected_result)
self.assertEqual(expected_result, decoded_input)
def test_tokenizer_mismatch_warning(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
with self.assertLogs("transformers", level="WARNING") as cm:
try:
if self.tokenizer_class == BertTokenizer:
AlbertTokenizer.from_pretrained(pretrained_name)
else:
BertTokenizer.from_pretrained(pretrained_name)
except (TypeError, AttributeError):
# Some tokenizers cannot be loaded into the target tokenizer at all and errors are returned,
# here we just check that the warning has been logged before the error is raised
pass
finally:
self.assertTrue(
cm.records[0].message.startswith(
"The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
)
)
try:
if self.rust_tokenizer_class == BertTokenizerFast:
AlbertTokenizerFast.from_pretrained(pretrained_name)
else:
BertTokenizerFast.from_pretrained(pretrained_name)
except (TypeError, AttributeError):
# Some tokenizers cannot be loaded into the target tokenizer at all and errors are returned,
# here we just check that the warning has been logged before the error is raised
pass
finally:
self.assertTrue(
cm.records[0].message.startswith(
"The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
)
)
@is_staging_test
class TokenizerPushToHubTester(unittest.TestCase):

View File

@@ -44,6 +44,11 @@ class PreTrainedTokenizationFastTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_paths[0])
tokenizer.save_pretrained(self.tmpdirname)
def test_tokenizer_mismatch_warning(self):
# We disable this test for PreTrainedTokenizerFast because it is the only tokenizer that is not linked to any
# model
pass
def test_pretrained_model_lists(self):
# We disable this test for PreTrainedTokenizerFast because it is the only tokenizer that is not linked to any
# model