From 08d609bfb8fbbaf508ae55c5cf414b262cc04061 Mon Sep 17 00:00:00 2001 From: Tomohiro Endo Date: Sat, 17 Jul 2021 22:52:21 +0900 Subject: [PATCH] Add tokenizers class mismatch detection between `cls` and checkpoint (#12619) * Detect mismatch by analyzing config * Fix comment * Fix import * Update src/transformers/tokenization_utils_base.py Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com> * Revise based on reviews * remove kwargs * Fix exception * Fix handling exception again * Disable mismatch test in PreTrainedTokenizerFast Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com> --- .../tokenization_bert_japanese.py | 2 +- src/transformers/tokenization_utils_base.py | 45 +++++++++++++++++++ tests/test_tokenization_bert_japanese.py | 21 +++++++++ tests/test_tokenization_common.py | 38 ++++++++++++++++ tests/test_tokenization_fast.py | 5 +++ 5 files changed, 110 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/bert_japanese/tokenization_bert_japanese.py b/src/transformers/models/bert_japanese/tokenization_bert_japanese.py index be62e92e05..ecd7df9b03 100644 --- a/src/transformers/models/bert_japanese/tokenization_bert_japanese.py +++ b/src/transformers/models/bert_japanese/tokenization_bert_japanese.py @@ -132,7 +132,7 @@ class BertJapaneseTokenizer(BertTokenizer): if not os.path.isfile(vocab_file): raise ValueError( f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained " - "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + "model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" ) self.vocab = load_vocab(vocab_file) self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 59ff00f0f7..5a2cf575ce 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1749,13 +1749,58 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): if tokenizer_config_file is not None: with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle: init_kwargs = json.load(tokenizer_config_handle) + # First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers. + config_tokenizer_class = init_kwargs.get("tokenizer_class") init_kwargs.pop("tokenizer_class", None) saved_init_inputs = init_kwargs.pop("init_inputs", ()) if not init_inputs: init_inputs = saved_init_inputs else: + config_tokenizer_class = None init_kwargs = init_configuration + if config_tokenizer_class is None: + from .models.auto.configuration_auto import AutoConfig + + # Second attempt. If we have not yet found tokenizer_class, let's try to use the config. + try: + config = AutoConfig.from_pretrained(pretrained_model_name_or_path) + config_tokenizer_class = config.tokenizer_class + except (OSError, ValueError, KeyError): + # skip if an error occured. + config = None + if config_tokenizer_class is None: + # Third attempt. If we have not yet found the original type of the tokenizer, + # we are loading we see if we can infer it from the type of the configuration file + from .models.auto.configuration_auto import CONFIG_MAPPING + from .models.auto.tokenization_auto import TOKENIZER_MAPPING + + if hasattr(config, "model_type"): + config_class = CONFIG_MAPPING.get(config.model_type) + else: + # Fallback: use pattern matching on the string. + config_class = None + for pattern, config_class_tmp in CONFIG_MAPPING.items(): + if pattern in str(pretrained_model_name_or_path): + config_class = config_class_tmp + break + + if config_class in TOKENIZER_MAPPING.keys(): + config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING[config_class] + if config_tokenizer_class is not None: + config_tokenizer_class = config_tokenizer_class.__name__ + else: + config_tokenizer_class = config_tokenizer_class_fast.__name__ + + if config_tokenizer_class is not None: + if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""): + logger.warning( + "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. " + "It may result in unexpected tokenization. \n" + f"The tokenizer class you load from this checkpoint is '{config_tokenizer_class}'. \n" + f"The class this function is called from is '{cls.__name__}'." + ) + # Update with newly provided kwargs init_kwargs.update(kwargs) diff --git a/tests/test_tokenization_bert_japanese.py b/tests/test_tokenization_bert_japanese.py index b42a14314a..5994225858 100644 --- a/tests/test_tokenization_bert_japanese.py +++ b/tests/test_tokenization_bert_japanese.py @@ -22,6 +22,7 @@ from transformers import AutoTokenizer from transformers.models.bert_japanese.tokenization_bert_japanese import ( VOCAB_FILES_NAMES, BertJapaneseTokenizer, + BertTokenizer, CharacterTokenizer, MecabTokenizer, WordpieceTokenizer, @@ -278,3 +279,23 @@ class AutoTokenizerCustomTest(unittest.TestCase): EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese" tokenizer = AutoTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID) self.assertIsInstance(tokenizer, BertJapaneseTokenizer) + + +class BertTokenizerMismatchTest(unittest.TestCase): + def test_tokenizer_mismatch_warning(self): + EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese" + with self.assertLogs("transformers", level="WARNING") as cm: + BertTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID) + self.assertTrue( + cm.records[0].message.startswith( + "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from." + ) + ) + EXAMPLE_BERT_ID = "bert-base-cased" + with self.assertLogs("transformers", level="WARNING") as cm: + BertJapaneseTokenizer.from_pretrained(EXAMPLE_BERT_ID) + self.assertTrue( + cm.records[0].message.startswith( + "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from." + ) + ) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index dbc6af764e..7e9dd887ee 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -29,7 +29,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union from huggingface_hub import HfApi from requests.exceptions import HTTPError from transformers import ( + AlbertTokenizer, + AlbertTokenizerFast, BertTokenizer, + BertTokenizerFast, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast, @@ -3292,6 +3295,41 @@ class TokenizerTesterMixin: expected_result = tokenizer.backend_tokenizer.normalizer.normalize_str(expected_result) self.assertEqual(expected_result, decoded_input) + def test_tokenizer_mismatch_warning(self): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + with self.assertLogs("transformers", level="WARNING") as cm: + try: + if self.tokenizer_class == BertTokenizer: + AlbertTokenizer.from_pretrained(pretrained_name) + else: + BertTokenizer.from_pretrained(pretrained_name) + except (TypeError, AttributeError): + # Some tokenizers cannot be loaded into the target tokenizer at all and errors are returned, + # here we just check that the warning has been logged before the error is raised + pass + finally: + self.assertTrue( + cm.records[0].message.startswith( + "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from." + ) + ) + try: + if self.rust_tokenizer_class == BertTokenizerFast: + AlbertTokenizerFast.from_pretrained(pretrained_name) + else: + BertTokenizerFast.from_pretrained(pretrained_name) + except (TypeError, AttributeError): + # Some tokenizers cannot be loaded into the target tokenizer at all and errors are returned, + # here we just check that the warning has been logged before the error is raised + pass + finally: + self.assertTrue( + cm.records[0].message.startswith( + "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from." + ) + ) + @is_staging_test class TokenizerPushToHubTester(unittest.TestCase): diff --git a/tests/test_tokenization_fast.py b/tests/test_tokenization_fast.py index de237aac18..c6472b0d8d 100644 --- a/tests/test_tokenization_fast.py +++ b/tests/test_tokenization_fast.py @@ -44,6 +44,11 @@ class PreTrainedTokenizationFastTest(TokenizerTesterMixin, unittest.TestCase): tokenizer = PreTrainedTokenizerFast.from_pretrained(model_paths[0]) tokenizer.save_pretrained(self.tmpdirname) + def test_tokenizer_mismatch_warning(self): + # We disable this test for PreTrainedTokenizerFast because it is the only tokenizer that is not linked to any + # model + pass + def test_pretrained_model_lists(self): # We disable this test for PreTrainedTokenizerFast because it is the only tokenizer that is not linked to any # model