[from_pretrained] Allow tokenizer_type ≠ model_type (#6995)

This commit is contained in:
Julien Chaumond
2020-09-09 10:22:59 +02:00
committed by GitHub
parent 03e363f9ae
commit ed71c21d6a
4 changed files with 28 additions and 1 deletions

View File

@@ -190,6 +190,7 @@ class PretrainedConfig(object):
self.num_labels = kwargs.pop("num_labels", 2)
# Tokenizer arguments TODO: eventually tokenizer and models should share the same config
self.tokenizer_class = kwargs.pop("tokenizer_class", None)
self.prefix = kwargs.pop("prefix", None)
self.bos_token_id = kwargs.pop("bos_token_id", None)
self.pad_token_id = kwargs.pop("pad_token_id", None)

View File

@@ -15,6 +15,7 @@ from .file_utils import _tf_available, _torch_available, _torch_tpu_available
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown"
DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
# Used to test Auto{Config, Model, Tokenizer} model_type detection.

View File

@@ -222,6 +222,17 @@ class AutoTokenizer:
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
use_fast = kwargs.pop("use_fast", False)
if config.tokenizer_class is not None:
if use_fast and not config.tokenizer_class.endswith("Fast"):
tokenizer_class_candidate = f"{config.tokenizer_class}Fast"
else:
tokenizer_class_candidate = config.tokenizer_class
tokenizer_class = globals().get(tokenizer_class_candidate)
if tokenizer_class is None:
raise ValueError("Tokenizer class {} does not exist or is not currently imported.")
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
for config_class, (tokenizer_class_py, tokenizer_class_fast) in TOKENIZER_MAPPING.items():
if isinstance(config, config_class):
if tokenizer_class_fast and use_fast: