Add tokenizers class mismatch detection between cls and checkpoint (#12619)
* Detect mismatch by analyzing config
* Fix comment
* Fix import
* Update src/transformers/tokenization_utils_base.py
Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
* Revise based on reviews
* remove kwargs
* Fix exception
* Fix handling exception again
* Disable mismatch test in PreTrainedTokenizerFast
Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
This commit is contained in:
@@ -132,7 +132,7 @@ class BertJapaneseTokenizer(BertTokenizer):
|
|||||||
if not os.path.isfile(vocab_file):
|
if not os.path.isfile(vocab_file):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
|
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
|
||||||
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
"model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
||||||
)
|
)
|
||||||
self.vocab = load_vocab(vocab_file)
|
self.vocab = load_vocab(vocab_file)
|
||||||
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
|
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
|
||||||
|
|||||||
@@ -1749,13 +1749,58 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
if tokenizer_config_file is not None:
|
if tokenizer_config_file is not None:
|
||||||
with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
|
with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
|
||||||
init_kwargs = json.load(tokenizer_config_handle)
|
init_kwargs = json.load(tokenizer_config_handle)
|
||||||
|
# First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
|
||||||
|
config_tokenizer_class = init_kwargs.get("tokenizer_class")
|
||||||
init_kwargs.pop("tokenizer_class", None)
|
init_kwargs.pop("tokenizer_class", None)
|
||||||
saved_init_inputs = init_kwargs.pop("init_inputs", ())
|
saved_init_inputs = init_kwargs.pop("init_inputs", ())
|
||||||
if not init_inputs:
|
if not init_inputs:
|
||||||
init_inputs = saved_init_inputs
|
init_inputs = saved_init_inputs
|
||||||
else:
|
else:
|
||||||
|
config_tokenizer_class = None
|
||||||
init_kwargs = init_configuration
|
init_kwargs = init_configuration
|
||||||
|
|
||||||
|
if config_tokenizer_class is None:
|
||||||
|
from .models.auto.configuration_auto import AutoConfig
|
||||||
|
|
||||||
|
# Second attempt. If we have not yet found tokenizer_class, let's try to use the config.
|
||||||
|
try:
|
||||||
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
|
||||||
|
config_tokenizer_class = config.tokenizer_class
|
||||||
|
except (OSError, ValueError, KeyError):
|
||||||
|
# skip if an error occurred.
|
||||||
|
config = None
|
||||||
|
if config_tokenizer_class is None:
|
||||||
|
# Third attempt. If we have not yet found the original type of the tokenizer
|
||||||
|
# we are loading, we see if we can infer it from the type of the configuration file.
|
||||||
|
from .models.auto.configuration_auto import CONFIG_MAPPING
|
||||||
|
from .models.auto.tokenization_auto import TOKENIZER_MAPPING
|
||||||
|
|
||||||
|
if hasattr(config, "model_type"):
|
||||||
|
config_class = CONFIG_MAPPING.get(config.model_type)
|
||||||
|
else:
|
||||||
|
# Fallback: use pattern matching on the string.
|
||||||
|
config_class = None
|
||||||
|
for pattern, config_class_tmp in CONFIG_MAPPING.items():
|
||||||
|
if pattern in str(pretrained_model_name_or_path):
|
||||||
|
config_class = config_class_tmp
|
||||||
|
break
|
||||||
|
|
||||||
|
if config_class in TOKENIZER_MAPPING.keys():
|
||||||
|
config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING[config_class]
|
||||||
|
if config_tokenizer_class is not None:
|
||||||
|
config_tokenizer_class = config_tokenizer_class.__name__
|
||||||
|
else:
|
||||||
|
config_tokenizer_class = config_tokenizer_class_fast.__name__
|
||||||
|
|
||||||
|
if config_tokenizer_class is not None:
|
||||||
|
if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""):
|
||||||
|
logger.warning(
|
||||||
|
"The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. "
|
||||||
|
"It may result in unexpected tokenization. \n"
|
||||||
|
f"The tokenizer class you load from this checkpoint is '{config_tokenizer_class}'. \n"
|
||||||
|
f"The class this function is called from is '{cls.__name__}'."
|
||||||
|
)
|
||||||
|
|
||||||
# Update with newly provided kwargs
|
# Update with newly provided kwargs
|
||||||
init_kwargs.update(kwargs)
|
init_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from transformers import AutoTokenizer
|
|||||||
from transformers.models.bert_japanese.tokenization_bert_japanese import (
|
from transformers.models.bert_japanese.tokenization_bert_japanese import (
|
||||||
VOCAB_FILES_NAMES,
|
VOCAB_FILES_NAMES,
|
||||||
BertJapaneseTokenizer,
|
BertJapaneseTokenizer,
|
||||||
|
BertTokenizer,
|
||||||
CharacterTokenizer,
|
CharacterTokenizer,
|
||||||
MecabTokenizer,
|
MecabTokenizer,
|
||||||
WordpieceTokenizer,
|
WordpieceTokenizer,
|
||||||
@@ -278,3 +279,23 @@ class AutoTokenizerCustomTest(unittest.TestCase):
|
|||||||
EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
|
EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
|
||||||
tokenizer = AutoTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
|
tokenizer = AutoTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
|
||||||
self.assertIsInstance(tokenizer, BertJapaneseTokenizer)
|
self.assertIsInstance(tokenizer, BertJapaneseTokenizer)
|
||||||
|
|
||||||
|
|
||||||
|
class BertTokenizerMismatchTest(unittest.TestCase):
|
||||||
|
def test_tokenizer_mismatch_warning(self):
|
||||||
|
EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
|
||||||
|
with self.assertLogs("transformers", level="WARNING") as cm:
|
||||||
|
BertTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
|
||||||
|
self.assertTrue(
|
||||||
|
cm.records[0].message.startswith(
|
||||||
|
"The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
EXAMPLE_BERT_ID = "bert-base-cased"
|
||||||
|
with self.assertLogs("transformers", level="WARNING") as cm:
|
||||||
|
BertJapaneseTokenizer.from_pretrained(EXAMPLE_BERT_ID)
|
||||||
|
self.assertTrue(
|
||||||
|
cm.records[0].message.startswith(
|
||||||
|
"The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|||||||
@@ -29,7 +29,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
|||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
AlbertTokenizer,
|
||||||
|
AlbertTokenizerFast,
|
||||||
BertTokenizer,
|
BertTokenizer,
|
||||||
|
BertTokenizerFast,
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
PreTrainedTokenizerFast,
|
PreTrainedTokenizerFast,
|
||||||
@@ -3292,6 +3295,41 @@ class TokenizerTesterMixin:
|
|||||||
expected_result = tokenizer.backend_tokenizer.normalizer.normalize_str(expected_result)
|
expected_result = tokenizer.backend_tokenizer.normalizer.normalize_str(expected_result)
|
||||||
self.assertEqual(expected_result, decoded_input)
|
self.assertEqual(expected_result, decoded_input)
|
||||||
|
|
||||||
|
def test_tokenizer_mismatch_warning(self):
|
||||||
|
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
||||||
|
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
||||||
|
with self.assertLogs("transformers", level="WARNING") as cm:
|
||||||
|
try:
|
||||||
|
if self.tokenizer_class == BertTokenizer:
|
||||||
|
AlbertTokenizer.from_pretrained(pretrained_name)
|
||||||
|
else:
|
||||||
|
BertTokenizer.from_pretrained(pretrained_name)
|
||||||
|
except (TypeError, AttributeError):
|
||||||
|
# Some tokenizers cannot be loaded into the target tokenizer at all and errors are returned,
|
||||||
|
# here we just check that the warning has been logged before the error is raised
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
self.assertTrue(
|
||||||
|
cm.records[0].message.startswith(
|
||||||
|
"The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
if self.rust_tokenizer_class == BertTokenizerFast:
|
||||||
|
AlbertTokenizerFast.from_pretrained(pretrained_name)
|
||||||
|
else:
|
||||||
|
BertTokenizerFast.from_pretrained(pretrained_name)
|
||||||
|
except (TypeError, AttributeError):
|
||||||
|
# Some tokenizers cannot be loaded into the target tokenizer at all and errors are returned,
|
||||||
|
# here we just check that the warning has been logged before the error is raised
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
self.assertTrue(
|
||||||
|
cm.records[0].message.startswith(
|
||||||
|
"The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
class TokenizerPushToHubTester(unittest.TestCase):
|
class TokenizerPushToHubTester(unittest.TestCase):
|
||||||
|
|||||||
@@ -44,6 +44,11 @@ class PreTrainedTokenizationFastTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_paths[0])
|
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_paths[0])
|
||||||
tokenizer.save_pretrained(self.tmpdirname)
|
tokenizer.save_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
|
def test_tokenizer_mismatch_warning(self):
|
||||||
|
# We disable this test for PreTrainedTokenizerFast because it is the only tokenizer that is not linked to any
|
||||||
|
# model
|
||||||
|
pass
|
||||||
|
|
||||||
def test_pretrained_model_lists(self):
|
def test_pretrained_model_lists(self):
|
||||||
# We disable this test for PreTrainedTokenizerFast because it is the only tokenizer that is not linked to any
|
# We disable this test for PreTrainedTokenizerFast because it is the only tokenizer that is not linked to any
|
||||||
# model
|
# model
|
||||||
|
|||||||
Reference in New Issue
Block a user