[from_pretrained] Allow tokenizer_type ≠ model_type (#6995)
This commit is contained in:
@@ -190,6 +190,7 @@ class PretrainedConfig(object):
|
||||
self.num_labels = kwargs.pop("num_labels", 2)
|
||||
|
||||
# Tokenizer arguments. TODO: eventually tokenizers and models should share the same config.
|
||||
self.tokenizer_class = kwargs.pop("tokenizer_class", None)
|
||||
self.prefix = kwargs.pop("prefix", None)
|
||||
self.bos_token_id = kwargs.pop("bos_token_id", None)
|
||||
self.pad_token_id = kwargs.pop("pad_token_id", None)
|
||||
|
||||
@@ -15,6 +15,7 @@ from .file_utils import _tf_available, _torch_available, _torch_tpu_available
|
||||
|
||||
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
|
||||
DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown"
|
||||
DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
|
||||
# Used to test Auto{Config, Model, Tokenizer} model_type detection.
|
||||
|
||||
|
||||
|
||||
@@ -222,6 +222,17 @@ class AutoTokenizer:
|
||||
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
|
||||
use_fast = kwargs.pop("use_fast", False)
|
||||
|
||||
if config.tokenizer_class is not None:
|
||||
if use_fast and not config.tokenizer_class.endswith("Fast"):
|
||||
tokenizer_class_candidate = f"{config.tokenizer_class}Fast"
|
||||
else:
|
||||
tokenizer_class_candidate = config.tokenizer_class
|
||||
tokenizer_class = globals().get(tokenizer_class_candidate)
|
||||
if tokenizer_class is None:
|
||||
# Name the class that failed to resolve; the message previously showed a literal "{}".
raise ValueError(f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported.")
|
||||
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
|
||||
for config_class, (tokenizer_class_py, tokenizer_class_fast) in TOKENIZER_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
if tokenizer_class_fast and use_fast:
|
||||
|
||||
Reference in New Issue
Block a user