AutoConfig + other Auto classes honor model_type

This commit is contained in:
Julien Chaumond
2020-01-11 02:46:17 +00:00
parent 2f32dfd33b
commit 4d1c98c012
9 changed files with 503 additions and 284 deletions

View File

@@ -17,6 +17,23 @@
import logging
from .configuration_auto import (
AlbertConfig,
AutoConfig,
BertConfig,
CamembertConfig,
CTRLConfig,
DistilBertConfig,
GPT2Config,
OpenAIGPTConfig,
RobertaConfig,
T5Config,
TransfoXLConfig,
XLMConfig,
XLMRobertaConfig,
XLNetConfig,
)
from .configuration_utils import PretrainedConfig
from .tokenization_albert import AlbertTokenizer
from .tokenization_bert import BertTokenizer
from .tokenization_bert_japanese import BertJapaneseTokenizer
@@ -43,7 +60,8 @@ class AutoTokenizer(object):
class method.
The `from_pretrained()` method takes care of returning the correct tokenizer class instance
using pattern matching on the `pretrained_model_name_or_path` string.
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The tokenizer class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
@@ -72,7 +90,7 @@ class AutoTokenizer(object):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
r""" Instantiate a one of the tokenizer classes of the library
r""" Instantiate one of the tokenizer classes of the library
from a pre-trained model vocabulary.
The tokenizer class to instantiate is selected as the first pattern matching
@@ -129,33 +147,38 @@ class AutoTokenizer(object):
tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')
"""
if "t5" in pretrained_model_name_or_path:
return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "distilbert" in pretrained_model_name_or_path:
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "albert" in pretrained_model_name_or_path:
return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "camembert" in pretrained_model_name_or_path:
return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "xlm-roberta" in pretrained_model_name_or_path:
return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "roberta" in pretrained_model_name_or_path:
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "bert-base-japanese" in pretrained_model_name_or_path:
config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if "bert-base-japanese" in pretrained_model_name_or_path:
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "bert" in pretrained_model_name_or_path:
if isinstance(config, T5Config):
return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, DistilBertConfig):
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, AlbertConfig):
return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, CamembertConfig):
return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, XLMRobertaConfig):
return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, RobertaConfig):
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, BertConfig):
return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "openai-gpt" in pretrained_model_name_or_path:
elif isinstance(config, OpenAIGPTConfig):
return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "gpt2" in pretrained_model_name_or_path:
elif isinstance(config, GPT2Config):
return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "transfo-xl" in pretrained_model_name_or_path:
elif isinstance(config, TransfoXLConfig):
return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "xlnet" in pretrained_model_name_or_path:
elif isinstance(config, XLNetConfig):
return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "xlm" in pretrained_model_name_or_path:
elif isinstance(config, XLMConfig):
return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "ctrl" in pretrained_model_name_or_path:
elif isinstance(config, CTRLConfig):
return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "