Add tokenizers class mismatch detection between cls and checkpoint (#12619)
* Detect mismatch by analyzing config
* Fix comment
* Fix import
* Update src/transformers/tokenization_utils_base.py
  Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
* Revise based on reviews
* remove kwargs
* Fix exception
* Fix handling exception again
* Disable mismatch test in PreTrainedTokenizerFast
Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
This commit is contained in:
@@ -1749,13 +1749,58 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
if tokenizer_config_file is not None:
|
||||
with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
|
||||
init_kwargs = json.load(tokenizer_config_handle)
|
||||
# First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
|
||||
config_tokenizer_class = init_kwargs.get("tokenizer_class")
|
||||
init_kwargs.pop("tokenizer_class", None)
|
||||
saved_init_inputs = init_kwargs.pop("init_inputs", ())
|
||||
if not init_inputs:
|
||||
init_inputs = saved_init_inputs
|
||||
else:
|
||||
config_tokenizer_class = None
|
||||
init_kwargs = init_configuration
|
||||
|
||||
if config_tokenizer_class is None:
|
||||
from .models.auto.configuration_auto import AutoConfig
|
||||
|
||||
# Second attempt. If we have not yet found tokenizer_class, let's try to use the config.
|
||||
try:
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
|
||||
config_tokenizer_class = config.tokenizer_class
|
||||
except (OSError, ValueError, KeyError):
|
||||
# skip if an error occurred.
|
||||
config = None
|
||||
if config_tokenizer_class is None:
|
||||
# Third attempt. If we have not yet found the original type of the tokenizer we are loading,
|
||||
# we see if we can infer it from the type of the configuration file.
|
||||
from .models.auto.configuration_auto import CONFIG_MAPPING
|
||||
from .models.auto.tokenization_auto import TOKENIZER_MAPPING
|
||||
|
||||
if hasattr(config, "model_type"):
|
||||
config_class = CONFIG_MAPPING.get(config.model_type)
|
||||
else:
|
||||
# Fallback: use pattern matching on the string.
|
||||
config_class = None
|
||||
for pattern, config_class_tmp in CONFIG_MAPPING.items():
|
||||
if pattern in str(pretrained_model_name_or_path):
|
||||
config_class = config_class_tmp
|
||||
break
|
||||
|
||||
if config_class in TOKENIZER_MAPPING.keys():
|
||||
config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING[config_class]
|
||||
if config_tokenizer_class is not None:
|
||||
config_tokenizer_class = config_tokenizer_class.__name__
|
||||
else:
|
||||
config_tokenizer_class = config_tokenizer_class_fast.__name__
|
||||
|
||||
if config_tokenizer_class is not None:
|
||||
if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""):
|
||||
logger.warning(
|
||||
"The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. "
|
||||
"It may result in unexpected tokenization. \n"
|
||||
f"The tokenizer class you load from this checkpoint is '{config_tokenizer_class}'. \n"
|
||||
f"The class this function is called from is '{cls.__name__}'."
|
||||
)
|
||||
|
||||
# Update with newly provided kwargs
|
||||
init_kwargs.update(kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user