diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py index 32a0385eca..53bd189f80 100644 --- a/src/transformers/configuration_auto.py +++ b/src/transformers/configuration_auto.py @@ -16,6 +16,7 @@ import logging +from collections import OrderedDict from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig @@ -27,6 +28,7 @@ from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, Open from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig +from .configuration_utils import PretrainedConfig from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig @@ -56,17 +58,38 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict( ) -class AutoConfig(object): +CONFIG_MAPPING = OrderedDict( + [ + ("t5", T5Config,), + ("distilbert", DistilBertConfig,), + ("albert", AlbertConfig,), + ("camembert", CamembertConfig,), + ("xlm-roberta", XLMRobertaConfig,), + ("roberta", RobertaConfig,), + ("bert", BertConfig,), + ("openai-gpt", OpenAIGPTConfig,), + ("gpt2", GPT2Config,), + ("transfo-xl", TransfoXLConfig,), + ("xlnet", XLNetConfig,), + ("xlm", XLMConfig,), + ("ctrl", CTRLConfig,), + ] +) + + +class AutoConfig: r""":class:`~transformers.AutoConfig` is a generic configuration class that will be instantiated as one of the configuration classes of the library when created with the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` class method. 
The `from_pretrained()` method take care of returning the correct model class instance - using pattern matching on the `pretrained_model_name_or_path` string. + based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. - The base model class to instantiate is selected as the first pattern matching - in the `pretrained_model_name_or_path` string (in the following order): + When using string matching, the configuration class is matched on + the `pretrained_model_name_or_path` string in the following order: + - contains `t5`: T5Config (T5 model) - contains `distilbert`: DistilBertConfig (DistilBERT model) - contains `albert`: AlbertConfig (ALBERT model) - contains `camembert`: CamembertConfig (CamemBERT model) @@ -90,41 +113,23 @@ class AutoConfig(object): @classmethod def for_model(cls, model_type, *args, **kwargs): - if "distilbert" in model_type: - return DistilBertConfig(*args, **kwargs) - elif "roberta" in model_type: - return RobertaConfig(*args, **kwargs) - elif "bert" in model_type: - return BertConfig(*args, **kwargs) - elif "openai-gpt" in model_type: - return OpenAIGPTConfig(*args, **kwargs) - elif "gpt2" in model_type: - return GPT2Config(*args, **kwargs) - elif "transfo-xl" in model_type: - return TransfoXLConfig(*args, **kwargs) - elif "xlnet" in model_type: - return XLNetConfig(*args, **kwargs) - elif "xlm" in model_type: - return XLMConfig(*args, **kwargs) - elif "ctrl" in model_type: - return CTRLConfig(*args, **kwargs) - elif "albert" in model_type: - return AlbertConfig(*args, **kwargs) - elif "camembert" in model_type: - return CamembertConfig(*args, **kwargs) + for pattern, config_class in CONFIG_MAPPING.items(): + if pattern in model_type: + return config_class(*args, **kwargs) raise ValueError( - "Unrecognized model identifier in {}. 
Should contains one of " - "'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " - "'xlm', 'roberta', 'ctrl', 'camembert', 'albert'".format(model_type) + "Unrecognized model identifier in {}. Should contain one of {}".format( + model_type, ", ".join(CONFIG_MAPPING.keys()) + ) ) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): - r""" Instantiate a one of the configuration classes of the library + r""" Instantiate one of the configuration classes of the library from a pre-trained model configuration. - The configuration class to instantiate is selected as the first pattern matching - in the `pretrained_model_name_or_path` string (in the following order): + The configuration class to instantiate is selected + based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. - contains `t5`: T5Config (T5 model) - contains `distilbert`: DistilBertConfig (DistilBERT model) - contains `albert`: AlbertConfig (ALBERT model) @@ -183,36 +188,21 @@ class AutoConfig(object): assert unused_kwargs == {'foo': False} """ - if "t5" in pretrained_model_name_or_path: - return T5Config.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif "distilbert" in pretrained_model_name_or_path: - return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif "albert" in pretrained_model_name_or_path: - return AlbertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif "camembert" in pretrained_model_name_or_path: - return CamembertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif "xlm-roberta" in pretrained_model_name_or_path: - return XLMRobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif "roberta" in pretrained_model_name_or_path: - return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif "bert" in 
pretrained_model_name_or_path: - return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif "openai-gpt" in pretrained_model_name_or_path: - return OpenAIGPTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif "gpt2" in pretrained_model_name_or_path: - return GPT2Config.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif "transfo-xl" in pretrained_model_name_or_path: - return TransfoXLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif "xlnet" in pretrained_model_name_or_path: - return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif "xlm" in pretrained_model_name_or_path: - return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif "ctrl" in pretrained_model_name_or_path: - return CTRLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config_dict, kwargs = PretrainedConfig.resolved_config_dict( + pretrained_model_name_or_path, pretrained_config_archive_map=ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, **kwargs + ) + + if config_dict.get("model_type") in CONFIG_MAPPING: + config_class = CONFIG_MAPPING[config_dict["model_type"]] + return config_class.from_dict(config_dict, **kwargs) + else: + # Fallback (`model_type` missing or unrecognized): use pattern matching on the string. + for pattern, config_class in CONFIG_MAPPING.items(): + if pattern in pretrained_model_name_or_path: + return config_class.from_dict(config_dict, **kwargs) + raise ValueError( - "Unrecognized model identifier in {}. Should contains one of " - "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " - "'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format( - pretrained_model_name_or_path + "Unrecognized model identifier in {}. 
Should have a `model_type` key in its config.json, or contain one of {}".format( + pretrained_model_name_or_path, ", ".join(CONFIG_MAPPING.keys()) ) ) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 862a4105a4..ea512a46a9 100644 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -20,6 +20,7 @@ import copy import json import logging import os +from typing import Dict, Optional, Tuple from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url @@ -36,7 +37,7 @@ class PretrainedConfig(object): It only affects the model's configuration. Class attributes (overridden by derived classes): - - ``pretrained_config_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained model configurations as values. + - ``pretrained_config_archive_map``: a python ``dict`` with `shortcut names` (string) as keys and `url` (string) of associated pretrained model configurations as values. Parameters: ``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint. @@ -154,14 +155,32 @@ class PretrainedConfig(object): assert unused_kwargs == {'foo': False} """ + config_dict, kwargs = cls.resolved_config_dict(pretrained_model_name_or_path, **kwargs) + return cls.from_dict(config_dict, **kwargs) + + @classmethod + def resolved_config_dict( + cls, pretrained_model_name_or_path: str, pretrained_config_archive_map: Optional[Dict] = None, **kwargs + ) -> Tuple[Dict, Dict]: + """ + From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used + for instantiating a Config using `from_dict`. + + Parameters: + pretrained_config_archive_map: (`optional`) Dict: + A map of `shortcut names` to `url`. + By default, will use the current class attribute. 
+ """ cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) - return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) - if pretrained_model_name_or_path in cls.pretrained_config_archive_map: - config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path] + if pretrained_config_archive_map is None: + pretrained_config_archive_map = cls.pretrained_config_archive_map + + if pretrained_model_name_or_path in pretrained_config_archive_map: + config_file = pretrained_config_archive_map[pretrained_model_name_or_path] elif os.path.isdir(pretrained_model_name_or_path): config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): @@ -178,23 +197,20 @@ class PretrainedConfig(object): proxies=proxies, resume_download=resume_download, ) - # Load config - config = cls.from_json_file(resolved_config_file) + # Load config dict + config_dict = cls._dict_from_json_file(resolved_config_file) except EnvironmentError: - if pretrained_model_name_or_path in cls.pretrained_config_archive_map: + if pretrained_model_name_or_path in pretrained_config_archive_map: msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format( config_file ) else: msg = ( - "Model name '{}' was not found in model name list ({}). " + "Model name '{}' was not found in model name list. 
" "We assumed '{}' was a path or url to a configuration file named {} or " "a directory containing such a file but couldn't find any such file at this path or url.".format( - pretrained_model_name_or_path, - ", ".join(cls.pretrained_config_archive_map.keys()), - config_file, - CONFIG_NAME, + pretrained_model_name_or_path, config_file, CONFIG_NAME, ) ) raise EnvironmentError(msg) @@ -212,6 +228,15 @@ class PretrainedConfig(object): else: logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file)) + return config_dict, kwargs + + @classmethod + def from_dict(cls, config_dict: Dict, **kwargs): + """Constructs a `Config` from a Python dictionary of parameters.""" + return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + + config = cls(**config_dict) + if hasattr(config, "pruned_heads"): config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items()) @@ -231,17 +256,16 @@ class PretrainedConfig(object): return config @classmethod - def from_dict(cls, json_object): - """Constructs a `Config` from a Python dictionary of parameters.""" - return cls(**json_object) + def from_json_file(cls, json_file: str): + """Constructs a `Config` from a json file of parameters.""" + config_dict = cls._dict_from_json_file(json_file) + return cls(**config_dict) @classmethod - def from_json_file(cls, json_file): - """Constructs a `Config` from a json file of parameters.""" + def _dict_from_json_file(cls, json_file: str): with open(json_file, "r", encoding="utf-8") as reader: text = reader.read() - dict_obj = json.loads(text) - return cls(**dict_obj) + return json.loads(text) def __eq__(self, other): return self.__dict__ == other.__dict__ diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index d2d18586be..ff134c58dc 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -19,6 +19,7 @@ import logging from .configuration_auto import ( 
AlbertConfig, + AutoConfig, BertConfig, CamembertConfig, CTRLConfig, @@ -26,11 +27,13 @@ from .configuration_auto import ( GPT2Config, OpenAIGPTConfig, RobertaConfig, + T5Config, TransfoXLConfig, XLMConfig, XLMRobertaConfig, XLNetConfig, ) +from .configuration_utils import PretrainedConfig from .modeling_albert import ( ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, AlbertForMaskedLM, @@ -129,7 +132,8 @@ class AutoModel(object): or the `AutoModel.from_config(config)` class methods. The `from_pretrained()` method takes care of returning the correct model class instance - using pattern matching on the `pretrained_model_name_or_path` string. + based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. The base model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): @@ -286,32 +290,36 @@ class AutoModel(object): model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) """ - if "t5" in pretrained_model_name_or_path: - return T5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "distilbert" in pretrained_model_name_or_path: - return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "albert" in pretrained_model_name_or_path: - return AlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "camembert" in pretrained_model_name_or_path: - return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "xlm-roberta" in pretrained_model_name_or_path: - return XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "roberta" in pretrained_model_name_or_path: - return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "bert" in 
pretrained_model_name_or_path: - return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "openai-gpt" in pretrained_model_name_or_path: - return OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "gpt2" in pretrained_model_name_or_path: - return GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "transfo-xl" in pretrained_model_name_or_path: - return TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "xlnet" in pretrained_model_name_or_path: - return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "xlm" in pretrained_model_name_or_path: - return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "ctrl" in pretrained_model_name_or_path: - return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + config = kwargs.pop("config", None) + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + if isinstance(config, T5Config): + return T5Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + elif isinstance(config, DistilBertConfig): + return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + elif isinstance(config, AlbertConfig): + return AlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + elif isinstance(config, CamembertConfig): + return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + elif isinstance(config, XLMRobertaConfig): + return XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + elif isinstance(config, RobertaConfig): + return RobertaModel.from_pretrained(pretrained_model_name_or_path, 
*model_args, config=config, **kwargs) + elif isinstance(config, BertConfig): + return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + elif isinstance(config, OpenAIGPTConfig): + return OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + elif isinstance(config, GPT2Config): + return GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + elif isinstance(config, TransfoXLConfig): + return TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + elif isinstance(config, XLNetConfig): + return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + elif isinstance(config, XLMConfig): + return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + elif isinstance(config, CTRLConfig): + return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) raise ValueError( "Unrecognized model identifier in {}. Should contains one of " "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " @@ -329,7 +337,8 @@ class AutoModelWithLMHead(object): class method. The `from_pretrained()` method takes care of returning the correct model class instance - using pattern matching on the `pretrained_model_name_or_path` string. + based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): @@ -407,7 +416,8 @@ class AutoModelWithLMHead(object): from a pre-trained model configuration. The `from_pretrained()` method takes care of returning the correct model class instance - using pattern matching on the `pretrained_model_name_or_path` string. 
+ based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): @@ -484,32 +494,56 @@ class AutoModelWithLMHead(object): model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) """ - if "t5" in pretrained_model_name_or_path: - return T5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "distilbert" in pretrained_model_name_or_path: - return DistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "albert" in pretrained_model_name_or_path: - return AlbertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "camembert" in pretrained_model_name_or_path: - return CamembertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "xlm-roberta" in pretrained_model_name_or_path: - return XLMRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "roberta" in pretrained_model_name_or_path: - return RobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "bert" in pretrained_model_name_or_path: - return BertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "openai-gpt" in pretrained_model_name_or_path: - return OpenAIGPTLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "gpt2" in pretrained_model_name_or_path: - return GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "transfo-xl" in pretrained_model_name_or_path: - return TransfoXLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif 
"xlnet" in pretrained_model_name_or_path: - return XLNetLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "xlm" in pretrained_model_name_or_path: - return XLMWithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "ctrl" in pretrained_model_name_or_path: - return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + config = kwargs.pop("config", None) + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + if isinstance(config, T5Config): + return T5WithLMHeadModel.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, DistilBertConfig): + return DistilBertForMaskedLM.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, AlbertConfig): + return AlbertForMaskedLM.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, CamembertConfig): + return CamembertForMaskedLM.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, XLMRobertaConfig): + return XLMRobertaForMaskedLM.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, RobertaConfig): + return RobertaForMaskedLM.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, BertConfig): + return BertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + elif isinstance(config, OpenAIGPTConfig): + return OpenAIGPTLMHeadModel.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, GPT2Config): + return GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, 
*model_args, config=config, **kwargs) + elif isinstance(config, TransfoXLConfig): + return TransfoXLLMHeadModel.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, XLNetConfig): + return XLNetLMHeadModel.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, XLMConfig): + return XLMWithLMHeadModel.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, CTRLConfig): + return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) raise ValueError( "Unrecognized model identifier in {}. Should contains one of " "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " @@ -527,7 +561,8 @@ class AutoModelForSequenceClassification(object): class method. The `from_pretrained()` method takes care of returning the correct model class instance - using pattern matching on the `pretrained_model_name_or_path` string. + based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): @@ -592,7 +627,8 @@ class AutoModelForSequenceClassification(object): from a pre-trained model configuration. The `from_pretrained()` method takes care of returning the correct model class instance - using pattern matching on the `pretrained_model_name_or_path` string. + based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. 
The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): @@ -665,32 +701,42 @@ class AutoModelForSequenceClassification(object): model = AutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) """ - if "distilbert" in pretrained_model_name_or_path: + config = kwargs.pop("config", None) + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + if isinstance(config, DistilBertConfig): return DistilBertForSequenceClassification.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + pretrained_model_name_or_path, *model_args, config=config, **kwargs ) - elif "albert" in pretrained_model_name_or_path: + elif isinstance(config, AlbertConfig): return AlbertForSequenceClassification.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + pretrained_model_name_or_path, *model_args, config=config, **kwargs ) - elif "camembert" in pretrained_model_name_or_path: + elif isinstance(config, CamembertConfig): return CamembertForSequenceClassification.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + pretrained_model_name_or_path, *model_args, config=config, **kwargs ) - elif "xlm-roberta" in pretrained_model_name_or_path: + elif isinstance(config, XLMRobertaConfig): return XLMRobertaForSequenceClassification.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + pretrained_model_name_or_path, *model_args, config=config, **kwargs ) - elif "roberta" in pretrained_model_name_or_path: + elif isinstance(config, RobertaConfig): return RobertaForSequenceClassification.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, BertConfig): + return 
BertForSequenceClassification.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, XLNetConfig): + return XLNetForSequenceClassification.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, XLMConfig): + return XLMForSequenceClassification.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs ) - elif "bert" in pretrained_model_name_or_path: - return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "xlnet" in pretrained_model_name_or_path: - return XLNetForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "xlm" in pretrained_model_name_or_path: - return XLMForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) raise ValueError( "Unrecognized model identifier in {}. Should contains one of " @@ -708,7 +754,8 @@ class AutoModelForQuestionAnswering(object): class method. The `from_pretrained()` method takes care of returning the correct model class instance - using pattern matching on the `pretrained_model_name_or_path` string. + based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): @@ -763,7 +810,8 @@ class AutoModelForQuestionAnswering(object): from a pre-trained model configuration. The `from_pretrained()` method takes care of returning the correct model class instance - using pattern matching on the `pretrained_model_name_or_path` string. + based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. 
The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): @@ -830,16 +878,30 @@ class AutoModelForQuestionAnswering(object): model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) """ - if "distilbert" in pretrained_model_name_or_path: - return DistilBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "albert" in pretrained_model_name_or_path: - return AlbertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "bert" in pretrained_model_name_or_path: - return BertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "xlnet" in pretrained_model_name_or_path: - return XLNetForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "xlm" in pretrained_model_name_or_path: - return XLMForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + config = kwargs.pop("config", None) + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + if isinstance(config, DistilBertConfig): + return DistilBertForQuestionAnswering.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, AlbertConfig): + return AlbertForQuestionAnswering.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, BertConfig): + return BertForQuestionAnswering.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, XLNetConfig): + return XLNetForQuestionAnswering.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, XLMConfig): + 
return XLMForQuestionAnswering.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) raise ValueError( "Unrecognized model identifier in {}. Should contains one of " @@ -893,7 +955,8 @@ class AutoModelForTokenClassification: from a pre-trained model configuration. The `from_pretrained()` method takes care of returning the correct model class instance - using pattern matching on the `pretrained_model_name_or_path` string. + based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): @@ -959,24 +1022,34 @@ class AutoModelForTokenClassification: model = AutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) """ - if "camembert" in pretrained_model_name_or_path: + config = kwargs.pop("config", None) + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + if isinstance(config, CamembertConfig): return CamembertForTokenClassification.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + pretrained_model_name_or_path, *model_args, config=config, **kwargs ) - elif "distilbert" in pretrained_model_name_or_path: + elif isinstance(config, DistilBertConfig): return DistilBertForTokenClassification.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + pretrained_model_name_or_path, *model_args, config=config, **kwargs ) - elif "xlm-roberta" in pretrained_model_name_or_path: + elif isinstance(config, XLMRobertaConfig): return XLMRobertaForTokenClassification.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif 
isinstance(config, RobertaConfig): + return RobertaForTokenClassification.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, BertConfig): + return BertForTokenClassification.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, XLNetConfig): + return XLNetForTokenClassification.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs ) - elif "roberta" in pretrained_model_name_or_path: - return RobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "bert" in pretrained_model_name_or_path: - return BertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "xlnet" in pretrained_model_name_or_path: - return XLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) raise ValueError( "Unrecognized model identifier in {}. Should contains one of " diff --git a/src/transformers/modeling_tf_auto.py b/src/transformers/modeling_tf_auto.py index 5463a73668..b43a9ac514 100644 --- a/src/transformers/modeling_tf_auto.py +++ b/src/transformers/modeling_tf_auto.py @@ -18,16 +18,20 @@ import logging from .configuration_auto import ( + AlbertConfig, + AutoConfig, BertConfig, CTRLConfig, DistilBertConfig, GPT2Config, OpenAIGPTConfig, RobertaConfig, + T5Config, TransfoXLConfig, XLMConfig, XLNetConfig, ) +from .configuration_utils import PretrainedConfig from .modeling_tf_albert import ( TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, TFAlbertForMaskedLM, @@ -113,7 +117,8 @@ class TFAutoModel(object): class method. The `from_pretrained()` method takes care of returning the correct model class instance - using pattern matching on the `pretrained_model_name_or_path` string. 
+ based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. The base model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): @@ -257,28 +262,38 @@ class TFAutoModel(object): model = TFAutoModel.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config) """ - if "t5" in pretrained_model_name_or_path: - return TFT5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "distilbert" in pretrained_model_name_or_path: - return TFDistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "albert" in pretrained_model_name_or_path: - return TFAlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "roberta" in pretrained_model_name_or_path: - return TFRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "bert" in pretrained_model_name_or_path: - return TFBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "openai-gpt" in pretrained_model_name_or_path: - return TFOpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "gpt2" in pretrained_model_name_or_path: - return TFGPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "transfo-xl" in pretrained_model_name_or_path: - return TFTransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "xlnet" in pretrained_model_name_or_path: - return TFXLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "xlm" in pretrained_model_name_or_path: - return TFXLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "ctrl" in pretrained_model_name_or_path: - return 
TFCTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + config = kwargs.pop("config", None) + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + if isinstance(config, T5Config): + return TFT5Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + elif isinstance(config, DistilBertConfig): + return TFDistilBertModel.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, AlbertConfig): + return TFAlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + elif isinstance(config, RobertaConfig): + return TFRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + elif isinstance(config, BertConfig): + return TFBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + elif isinstance(config, OpenAIGPTConfig): + return TFOpenAIGPTModel.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, GPT2Config): + return TFGPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + elif isinstance(config, TransfoXLConfig): + return TFTransfoXLModel.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, XLNetConfig): + return TFXLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + elif isinstance(config, XLMConfig): + return TFXLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + elif isinstance(config, CTRLConfig): + return TFCTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) raise ValueError( "Unrecognized model identifier in {}. 
Should contains one of " @@ -295,7 +310,8 @@ class TFAutoModelWithLMHead(object): class method. The `from_pretrained()` method takes care of returning the correct model class instance - using pattern matching on the `pretrained_model_name_or_path` string. + based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): @@ -368,7 +384,8 @@ class TFAutoModelWithLMHead(object): from a pre-trained model configuration. The `from_pretrained()` method takes care of returning the correct model class instance - using pattern matching on the `pretrained_model_name_or_path` string. + based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): @@ -443,28 +460,54 @@ class TFAutoModelWithLMHead(object): model = TFAutoModelWithLMHead.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config) """ - if "t5" in pretrained_model_name_or_path: - return TFT5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "distilbert" in pretrained_model_name_or_path: - return TFDistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "albert" in pretrained_model_name_or_path: - return TFAlbertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "roberta" in pretrained_model_name_or_path: - return TFRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "bert" in pretrained_model_name_or_path: - return 
TFBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "openai-gpt" in pretrained_model_name_or_path: - return TFOpenAIGPTLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "gpt2" in pretrained_model_name_or_path: - return TFGPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "transfo-xl" in pretrained_model_name_or_path: - return TFTransfoXLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "xlnet" in pretrained_model_name_or_path: - return TFXLNetLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "xlm" in pretrained_model_name_or_path: - return TFXLMWithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "ctrl" in pretrained_model_name_or_path: - return TFCTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + config = kwargs.pop("config", None) + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + if isinstance(config, T5Config): + return TFT5WithLMHeadModel.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, DistilBertConfig): + return TFDistilBertForMaskedLM.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, AlbertConfig): + return TFAlbertForMaskedLM.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, RobertaConfig): + return TFRobertaForMaskedLM.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, BertConfig): + return TFBertForMaskedLM.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, OpenAIGPTConfig): 
+ return TFOpenAIGPTLMHeadModel.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, GPT2Config): + return TFGPT2LMHeadModel.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, TransfoXLConfig): + return TFTransfoXLLMHeadModel.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, XLNetConfig): + return TFXLNetLMHeadModel.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, XLMConfig): + return TFXLMWithLMHeadModel.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, CTRLConfig): + return TFCTRLLMHeadModel.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) raise ValueError( "Unrecognized model identifier in {}. Should contains one of " @@ -481,7 +524,8 @@ class TFAutoModelForSequenceClassification(object): class method. The `from_pretrained()` method takes care of returning the correct model class instance - using pattern matching on the `pretrained_model_name_or_path` string. + based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): @@ -537,7 +581,8 @@ class TFAutoModelForSequenceClassification(object): from a pre-trained model configuration. The `from_pretrained()` method takes care of returning the correct model class instance - using pattern matching on the `pretrained_model_name_or_path` string. 
+ based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): @@ -610,28 +655,34 @@ class TFAutoModelForSequenceClassification(object): model = TFAutoModelForSequenceClassification.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config) """ - if "distilbert" in pretrained_model_name_or_path: + config = kwargs.pop("config", None) + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + if isinstance(config, DistilBertConfig): return TFDistilBertForSequenceClassification.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + pretrained_model_name_or_path, *model_args, config=config, **kwargs ) - elif "albert" in pretrained_model_name_or_path: + elif isinstance(config, AlbertConfig): return TFAlbertForSequenceClassification.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + pretrained_model_name_or_path, *model_args, config=config, **kwargs ) - elif "roberta" in pretrained_model_name_or_path: + elif isinstance(config, RobertaConfig): return TFRobertaForSequenceClassification.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + pretrained_model_name_or_path, *model_args, config=config, **kwargs ) - elif "bert" in pretrained_model_name_or_path: + elif isinstance(config, BertConfig): return TFBertForSequenceClassification.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + pretrained_model_name_or_path, *model_args, config=config, **kwargs ) - elif "xlnet" in pretrained_model_name_or_path: + elif isinstance(config, XLNetConfig): return TFXLNetForSequenceClassification.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + 
pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, XLMConfig): + return TFXLMForSequenceClassification.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs ) - elif "xlm" in pretrained_model_name_or_path: - return TFXLMForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) raise ValueError( "Unrecognized model identifier in {}. Should contains one of " @@ -647,7 +698,8 @@ class TFAutoModelForQuestionAnswering(object): class method. The `from_pretrained()` method takes care of returning the correct model class instance - using pattern matching on the `pretrained_model_name_or_path` string. + based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): @@ -699,7 +751,8 @@ class TFAutoModelForQuestionAnswering(object): from a pre-trained model configuration. The `from_pretrained()` method takes care of returning the correct model class instance - using pattern matching on the `pretrained_model_name_or_path` string. + based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. 
The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): @@ -771,19 +824,25 @@ class TFAutoModelForQuestionAnswering(object): model = TFAutoModelForQuestionAnswering.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config) """ - if "distilbert" in pretrained_model_name_or_path: + config = kwargs.pop("config", None) + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + if isinstance(config, DistilBertConfig): return TFDistilBertForQuestionAnswering.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + pretrained_model_name_or_path, *model_args, config=config, **kwargs ) - elif "bert" in pretrained_model_name_or_path: - return TFBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "xlnet" in pretrained_model_name_or_path: + elif isinstance(config, BertConfig): + return TFBertForQuestionAnswering.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, XLNetConfig): return TFXLNetForQuestionAnsweringSimple.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + pretrained_model_name_or_path, *model_args, config=config, **kwargs ) - elif "xlm" in pretrained_model_name_or_path: + elif isinstance(config, XLMConfig): return TFXLMForQuestionAnsweringSimple.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + pretrained_model_name_or_path, *model_args, config=config, **kwargs ) raise ValueError( @@ -833,7 +892,8 @@ class TFAutoModelForTokenClassification: from a pre-trained model configuration. The `from_pretrained()` method takes care of returning the correct model class instance - using pattern matching on the `pretrained_model_name_or_path` string. 
+ based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): @@ -898,17 +958,25 @@ class TFAutoModelForTokenClassification: model = TFAutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) """ - if "bert" in pretrained_model_name_or_path: - return TFBertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "xlnet" in pretrained_model_name_or_path: - return TFXLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif "distilbert" in pretrained_model_name_or_path: - return TFDistilBertForTokenClassification.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + config = kwargs.pop("config", None) + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + if isinstance(config, BertConfig): + return TFBertForTokenClassification.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs ) - elif "roberta" in pretrained_model_name_or_path: + elif isinstance(config, XLNetConfig): + return TFXLNetForTokenClassification.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, DistilBertConfig): + return TFDistilBertForTokenClassification.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + elif isinstance(config, RobertaConfig): return TFRobertaForTokenClassification.from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + pretrained_model_name_or_path, *model_args, config=config, **kwargs ) raise ValueError( diff --git 
a/src/transformers/tokenization_auto.py b/src/transformers/tokenization_auto.py index e01f017733..d054924e81 100644 --- a/src/transformers/tokenization_auto.py +++ b/src/transformers/tokenization_auto.py @@ -17,6 +17,23 @@ import logging +from .configuration_auto import ( + AlbertConfig, + AutoConfig, + BertConfig, + CamembertConfig, + CTRLConfig, + DistilBertConfig, + GPT2Config, + OpenAIGPTConfig, + RobertaConfig, + T5Config, + TransfoXLConfig, + XLMConfig, + XLMRobertaConfig, + XLNetConfig, +) +from .configuration_utils import PretrainedConfig from .tokenization_albert import AlbertTokenizer from .tokenization_bert import BertTokenizer from .tokenization_bert_japanese import BertJapaneseTokenizer @@ -43,7 +60,8 @@ class AutoTokenizer(object): class method. The `from_pretrained()` method take care of returning the correct tokenizer class instance - using pattern matching on the `pretrained_model_name_or_path` string. + based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. The tokenizer class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): @@ -72,7 +90,7 @@ class AutoTokenizer(object): @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): - r""" Instantiate a one of the tokenizer classes of the library + r""" Instantiate one of the tokenizer classes of the library from a pre-trained model vocabulary. 
The tokenizer class to instantiate is selected as the first pattern matching @@ -129,33 +147,38 @@ class AutoTokenizer(object): tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/') """ - if "t5" in pretrained_model_name_or_path: - return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) - elif "distilbert" in pretrained_model_name_or_path: - return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) - elif "albert" in pretrained_model_name_or_path: - return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) - elif "camembert" in pretrained_model_name_or_path: - return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) - elif "xlm-roberta" in pretrained_model_name_or_path: - return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) - elif "roberta" in pretrained_model_name_or_path: - return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) - elif "bert-base-japanese" in pretrained_model_name_or_path: + config = kwargs.pop("config", None) + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + if "bert-base-japanese" in pretrained_model_name_or_path: return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) - elif "bert" in pretrained_model_name_or_path: + + if isinstance(config, T5Config): + return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + elif isinstance(config, DistilBertConfig): + return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + elif isinstance(config, AlbertConfig): + return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + elif isinstance(config, CamembertConfig): + return 
CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + elif isinstance(config, XLMRobertaConfig): + return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + elif isinstance(config, RobertaConfig): + return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + elif isinstance(config, BertConfig): return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) - elif "openai-gpt" in pretrained_model_name_or_path: + elif isinstance(config, OpenAIGPTConfig): return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) - elif "gpt2" in pretrained_model_name_or_path: + elif isinstance(config, GPT2Config): return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) - elif "transfo-xl" in pretrained_model_name_or_path: + elif isinstance(config, TransfoXLConfig): return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) - elif "xlnet" in pretrained_model_name_or_path: + elif isinstance(config, XLNetConfig): return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) - elif "xlm" in pretrained_model_name_or_path: + elif isinstance(config, XLMConfig): return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) - elif "ctrl" in pretrained_model_name_or_path: + elif isinstance(config, CTRLConfig): return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) raise ValueError( "Unrecognized model identifier in {}. 
Should contains one of " diff --git a/tests/fixtures/dummy-config.json b/tests/fixtures/dummy-config.json new file mode 100644 index 0000000000..e388bdf711 --- /dev/null +++ b/tests/fixtures/dummy-config.json @@ -0,0 +1,3 @@ +{ + "model_type": "roberta" +} \ No newline at end of file diff --git a/tests/test_configuration_auto.py b/tests/test_configuration_auto.py new file mode 100644 index 0000000000..842732da46 --- /dev/null +++ b/tests/test_configuration_auto.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import unittest + +from transformers.configuration_auto import AutoConfig +from transformers.configuration_bert import BertConfig +from transformers.configuration_roberta import RobertaConfig + + +SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json") + + +class AutoConfigTest(unittest.TestCase): + def test_config_from_model_shortcut(self): + config = AutoConfig.from_pretrained("bert-base-uncased") + self.assertIsInstance(config, BertConfig) + + def test_config_from_model_type(self): + config = AutoConfig.from_pretrained(SAMPLE_ROBERTA_CONFIG) + self.assertIsInstance(config, RobertaConfig) + + def test_config_for_model_str(self): + config = AutoConfig.for_model("roberta") + self.assertIsInstance(config, RobertaConfig) diff --git a/tests/test_modeling_auto.py b/tests/test_modeling_auto.py index c91e3cd3ed..dcf2526577 100644 --- a/tests/test_modeling_auto.py +++ b/tests/test_modeling_auto.py @@ -83,7 +83,7 @@ class AutoModelTest(unittest.TestCase): self.assertIsNotNone(model) self.assertIsInstance(model, BertForSequenceClassification) - @slow + # @slow def test_question_answering_model_from_pretrained(self): logging.basicConfig(level=logging.INFO) for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: diff --git a/tests/test_tokenization_auto.py b/tests/test_tokenization_auto.py index c2da723543..7ee544eea7 100644 --- a/tests/test_tokenization_auto.py +++ b/tests/test_tokenization_auto.py @@ -29,7 +29,7 @@ from .utils import SMALL_MODEL_IDENTIFIER, slow class AutoTokenizerTest(unittest.TestCase): - @slow + # @slow def test_tokenizer_from_pretrained(self): logging.basicConfig(level=logging.INFO) for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]: