Map configs to models and tokenizers
This commit is contained in:
@@ -16,6 +16,8 @@
|
||||
|
||||
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Type
|
||||
|
||||
from .configuration_auto import (
|
||||
AlbertConfig,
|
||||
@@ -45,6 +47,7 @@ from .tokenization_openai import OpenAIGPTTokenizer
|
||||
from .tokenization_roberta import RobertaTokenizer
|
||||
from .tokenization_t5 import T5Tokenizer
|
||||
from .tokenization_transfo_xl import TransfoXLTokenizer
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_xlm import XLMTokenizer
|
||||
from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||
from .tokenization_xlnet import XLNetTokenizer
|
||||
@@ -53,6 +56,25 @@ from .tokenization_xlnet import XLNetTokenizer
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
TOKENIZER_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedTokenizer]] = OrderedDict(
|
||||
[
|
||||
(T5Config, T5Tokenizer),
|
||||
(DistilBertConfig, DistilBertTokenizer),
|
||||
(AlbertConfig, AlbertTokenizer),
|
||||
(CamembertConfig, CamembertTokenizer),
|
||||
(RobertaConfig, XLMRobertaTokenizer),
|
||||
(XLMRobertaConfig, RobertaTokenizer),
|
||||
(BertConfig, BertTokenizer),
|
||||
(OpenAIGPTConfig, OpenAIGPTTokenizer),
|
||||
(GPT2Config, GPT2Tokenizer),
|
||||
(TransfoXLConfig, TransfoXLTokenizer),
|
||||
(XLNetConfig, XLNetTokenizer),
|
||||
(XLMConfig, XLMTokenizer),
|
||||
(CTRLConfig, CTRLTokenizer),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class AutoTokenizer(object):
|
||||
r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
|
||||
that will be instantiated as one of the tokenizer classes of the library
|
||||
@@ -154,36 +176,13 @@ class AutoTokenizer(object):
|
||||
if "bert-base-japanese" in pretrained_model_name_or_path:
|
||||
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
|
||||
if isinstance(config, T5Config):
|
||||
return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, DistilBertConfig):
|
||||
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, AlbertConfig):
|
||||
return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, CamembertConfig):
|
||||
return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, XLMRobertaConfig):
|
||||
return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, RobertaConfig):
|
||||
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, BertConfig):
|
||||
return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, OpenAIGPTConfig):
|
||||
return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, GPT2Config):
|
||||
return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, TransfoXLConfig):
|
||||
return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, XLNetConfig):
|
||||
return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, XLMConfig):
|
||||
return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, CTRLConfig):
|
||||
return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
for config_class, tokenizer_class in TOKENIZER_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
|
||||
raise ValueError(
|
||||
"Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
"'xlm-roberta', 'xlm', 'roberta', 'distilbert,' 'camembert', 'ctrl', 'albert'".format(
|
||||
pretrained_model_name_or_path
|
||||
"Unrecognized configuration class {} to build an AutoTokenizer.\n"
|
||||
"Model type should be one of {}.".format(
|
||||
config.__class__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user