AutoConfig + other Auto classes honor model_type
This commit is contained in:
@@ -17,6 +17,23 @@
|
||||
|
||||
import logging
|
||||
|
||||
from .configuration_auto import (
|
||||
AlbertConfig,
|
||||
AutoConfig,
|
||||
BertConfig,
|
||||
CamembertConfig,
|
||||
CTRLConfig,
|
||||
DistilBertConfig,
|
||||
GPT2Config,
|
||||
OpenAIGPTConfig,
|
||||
RobertaConfig,
|
||||
T5Config,
|
||||
TransfoXLConfig,
|
||||
XLMConfig,
|
||||
XLMRobertaConfig,
|
||||
XLNetConfig,
|
||||
)
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .tokenization_albert import AlbertTokenizer
|
||||
from .tokenization_bert import BertTokenizer
|
||||
from .tokenization_bert_japanese import BertJapaneseTokenizer
|
||||
@@ -43,7 +60,8 @@ class AutoTokenizer(object):
|
||||
class method.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct tokenizer class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The tokenizer class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
@@ -72,7 +90,7 @@ class AutoTokenizer(object):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||
r""" Instantiate a one of the tokenizer classes of the library
|
||||
r""" Instantiate one of the tokenizer classes of the library
|
||||
from a pre-trained model vocabulary.
|
||||
|
||||
The tokenizer class to instantiate is selected as the first pattern matching
|
||||
@@ -129,33 +147,38 @@ class AutoTokenizer(object):
|
||||
tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')
|
||||
|
||||
"""
|
||||
if "t5" in pretrained_model_name_or_path:
|
||||
return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "distilbert" in pretrained_model_name_or_path:
|
||||
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "albert" in pretrained_model_name_or_path:
|
||||
return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "camembert" in pretrained_model_name_or_path:
|
||||
return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "xlm-roberta" in pretrained_model_name_or_path:
|
||||
return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "roberta" in pretrained_model_name_or_path:
|
||||
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "bert-base-japanese" in pretrained_model_name_or_path:
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if "bert-base-japanese" in pretrained_model_name_or_path:
|
||||
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "bert" in pretrained_model_name_or_path:
|
||||
|
||||
if isinstance(config, T5Config):
|
||||
return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, DistilBertConfig):
|
||||
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, AlbertConfig):
|
||||
return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, CamembertConfig):
|
||||
return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, XLMRobertaConfig):
|
||||
return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, RobertaConfig):
|
||||
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, BertConfig):
|
||||
return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "openai-gpt" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, OpenAIGPTConfig):
|
||||
return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "gpt2" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, GPT2Config):
|
||||
return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "transfo-xl" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, TransfoXLConfig):
|
||||
return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "xlnet" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, XLNetConfig):
|
||||
return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "xlm" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, XLMConfig):
|
||||
return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "ctrl" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, CTRLConfig):
|
||||
return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
raise ValueError(
|
||||
"Unrecognized model identifier in {}. Should contains one of "
|
||||
|
||||
Reference in New Issue
Block a user