Map configs to models and tokenizers

This commit is contained in:
Julien Chaumond
2020-01-13 23:11:44 +00:00
parent 1fc855e456
commit 0304628590
6 changed files with 322 additions and 458 deletions

View File

@@ -16,6 +16,8 @@
import logging
from collections import OrderedDict
from typing import Dict, Type
from .configuration_auto import (
AlbertConfig,
@@ -45,6 +47,7 @@ from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_roberta import RobertaTokenizer
from .tokenization_t5 import T5Tokenizer
from .tokenization_transfo_xl import TransfoXLTokenizer
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_xlm import XLMTokenizer
from .tokenization_xlm_roberta import XLMRobertaTokenizer
from .tokenization_xlnet import XLNetTokenizer
@@ -53,6 +56,25 @@ from .tokenization_xlnet import XLNetTokenizer
logger = logging.getLogger(__name__)
TOKENIZER_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedTokenizer]] = OrderedDict(
[
(T5Config, T5Tokenizer),
(DistilBertConfig, DistilBertTokenizer),
(AlbertConfig, AlbertTokenizer),
(CamembertConfig, CamembertTokenizer),
(RobertaConfig, XLMRobertaTokenizer),
(XLMRobertaConfig, RobertaTokenizer),
(BertConfig, BertTokenizer),
(OpenAIGPTConfig, OpenAIGPTTokenizer),
(GPT2Config, GPT2Tokenizer),
(TransfoXLConfig, TransfoXLTokenizer),
(XLNetConfig, XLNetTokenizer),
(XLMConfig, XLMTokenizer),
(CTRLConfig, CTRLTokenizer),
]
)
class AutoTokenizer(object):
r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
that will be instantiated as one of the tokenizer classes of the library
@@ -154,36 +176,13 @@ class AutoTokenizer(object):
if "bert-base-japanese" in pretrained_model_name_or_path:
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
if isinstance(config, T5Config):
return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, DistilBertConfig):
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, AlbertConfig):
return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, CamembertConfig):
return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, XLMRobertaConfig):
return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, RobertaConfig):
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, BertConfig):
return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, OpenAIGPTConfig):
return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, GPT2Config):
return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, TransfoXLConfig):
return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, XLNetConfig):
return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, XLMConfig):
return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, CTRLConfig):
return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
for config_class, tokenizer_class in TOKENIZER_MAPPING.items():
if isinstance(config, config_class):
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm-roberta', 'xlm', 'roberta', 'distilbert,' 'camembert', 'ctrl', 'albert'".format(
pretrained_model_name_or_path
"Unrecognized configuration class {} to build an AutoTokenizer.\n"
"Model type should be one of {}.".format(
config.__class__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
)
)