Add tokenizers class mismatch detection between cls and checkpoint (#12619)

* Detect mismatch by analyzing config

* Fix comment

* Fix import

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>

* Revise based on reviews

* remove kwargs

* Fix exception

* Fix handling exception again

* Disable mismatch test in PreTrainedTokenizerFast

Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
This commit is contained in:
Tomohiro Endo
2021-07-17 22:52:21 +09:00
committed by GitHub
parent b4b562d834
commit 08d609bfb8
5 changed files with 110 additions and 1 deletions

View File

@@ -29,7 +29,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import (
AlbertTokenizer,
AlbertTokenizerFast,
BertTokenizer,
BertTokenizerFast,
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
@@ -3292,6 +3295,41 @@ class TokenizerTesterMixin:
expected_result = tokenizer.backend_tokenizer.normalizer.normalize_str(expected_result)
self.assertEqual(expected_result, decoded_input)
def test_tokenizer_mismatch_warning(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
with self.assertLogs("transformers", level="WARNING") as cm:
try:
if self.tokenizer_class == BertTokenizer:
AlbertTokenizer.from_pretrained(pretrained_name)
else:
BertTokenizer.from_pretrained(pretrained_name)
except (TypeError, AttributeError):
# Some tokenizers cannot be loaded into the target tokenizer at all and errors are returned,
# here we just check that the warning has been logged before the error is raised
pass
finally:
self.assertTrue(
cm.records[0].message.startswith(
"The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
)
)
try:
if self.rust_tokenizer_class == BertTokenizerFast:
AlbertTokenizerFast.from_pretrained(pretrained_name)
else:
BertTokenizerFast.from_pretrained(pretrained_name)
except (TypeError, AttributeError):
# Some tokenizers cannot be loaded into the target tokenizer at all and errors are returned,
# here we just check that the warning has been logged before the error is raised
pass
finally:
self.assertTrue(
cm.records[0].message.startswith(
"The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
)
)
@is_staging_test
class TokenizerPushToHubTester(unittest.TestCase):