From b803b067bfb7740a7b9ae6ae149fbbd24e3f21af Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 13 Jan 2020 20:05:20 +0000 Subject: [PATCH] Config to Model mapping --- .../summarization/configuration_bertabs.py | 1 + src/transformers/configuration_albert.py | 1 + src/transformers/configuration_auto.py | 2 +- src/transformers/configuration_bert.py | 1 + src/transformers/configuration_camembert.py | 1 + src/transformers/configuration_ctrl.py | 1 + src/transformers/configuration_distilbert.py | 1 + src/transformers/configuration_gpt2.py | 1 + src/transformers/configuration_openai.py | 1 + src/transformers/configuration_t5.py | 1 + src/transformers/configuration_transfo_xl.py | 1 + src/transformers/configuration_utils.py | 9 +- src/transformers/configuration_xlm.py | 1 + src/transformers/configuration_xlm_roberta.py | 1 + src/transformers/configuration_xlnet.py | 1 + src/transformers/modeling_auto.py | 88 ++++++++----------- .../adding_a_new_model/configuration_xxx.py | 1 + 17 files changed, 58 insertions(+), 55 deletions(-) diff --git a/examples/summarization/configuration_bertabs.py b/examples/summarization/configuration_bertabs.py index 530fb61074..aa51d63980 100644 --- a/examples/summarization/configuration_bertabs.py +++ b/examples/summarization/configuration_bertabs.py @@ -62,6 +62,7 @@ class BertAbsConfig(PretrainedConfig): """ pretrained_config_archive_map = BERTABS_FINETUNED_CONFIG_MAP + model_type = "bertabs" def __init__( self, diff --git a/src/transformers/configuration_albert.py b/src/transformers/configuration_albert.py index 1d6adfa7e9..2ac969538b 100644 --- a/src/transformers/configuration_albert.py +++ b/src/transformers/configuration_albert.py @@ -37,6 +37,7 @@ class AlbertConfig(PretrainedConfig): """ pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP + model_type = "albert" def __init__( self, diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py index 53bd189f80..904fa9f9ec 100644 --- 
a/src/transformers/configuration_auto.py +++ b/src/transformers/configuration_auto.py @@ -188,7 +188,7 @@ class AutoConfig: assert unused_kwargs == {'foo': False} """ - config_dict, _ = PretrainedConfig.resolved_config_dict( + config_dict, _ = PretrainedConfig.get_config_dict( pretrained_model_name_or_path, pretrained_config_archive_map=ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, **kwargs ) diff --git a/src/transformers/configuration_bert.py b/src/transformers/configuration_bert.py index 32fa50a504..762a8da9fd 100644 --- a/src/transformers/configuration_bert.py +++ b/src/transformers/configuration_bert.py @@ -78,6 +78,7 @@ class BertConfig(PretrainedConfig): layer_norm_eps: The epsilon used by LayerNorm. """ pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP + model_type = "bert" def __init__( self, diff --git a/src/transformers/configuration_camembert.py b/src/transformers/configuration_camembert.py index 8ecdf714b1..a4263556aa 100644 --- a/src/transformers/configuration_camembert.py +++ b/src/transformers/configuration_camembert.py @@ -30,3 +30,4 @@ CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { class CamembertConfig(RobertaConfig): pretrained_config_archive_map = CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP + model_type = "camembert" diff --git a/src/transformers/configuration_ctrl.py b/src/transformers/configuration_ctrl.py index e23bf7a376..9becc78754 100644 --- a/src/transformers/configuration_ctrl.py +++ b/src/transformers/configuration_ctrl.py @@ -48,6 +48,7 @@ class CTRLConfig(PretrainedConfig): """ pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP + model_type = "ctrl" def __init__( self, diff --git a/src/transformers/configuration_distilbert.py b/src/transformers/configuration_distilbert.py index 1dd4a11912..b86a9f7fa8 100644 --- a/src/transformers/configuration_distilbert.py +++ b/src/transformers/configuration_distilbert.py @@ -32,6 +32,7 @@ DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { class DistilBertConfig(PretrainedConfig): 
pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP + model_type = "distilbert" def __init__( self, diff --git a/src/transformers/configuration_gpt2.py b/src/transformers/configuration_gpt2.py index 8da1800747..c91d01b139 100644 --- a/src/transformers/configuration_gpt2.py +++ b/src/transformers/configuration_gpt2.py @@ -54,6 +54,7 @@ class GPT2Config(PretrainedConfig): """ pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP + model_type = "gpt2" def __init__( self, diff --git a/src/transformers/configuration_openai.py b/src/transformers/configuration_openai.py index d7e88bda92..b2ad81eb02 100644 --- a/src/transformers/configuration_openai.py +++ b/src/transformers/configuration_openai.py @@ -54,6 +54,7 @@ class OpenAIGPTConfig(PretrainedConfig): """ pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP + model_type = "openai-gpt" def __init__( self, diff --git a/src/transformers/configuration_t5.py b/src/transformers/configuration_t5.py index deef206d01..c7016baed5 100644 --- a/src/transformers/configuration_t5.py +++ b/src/transformers/configuration_t5.py @@ -60,6 +60,7 @@ class T5Config(PretrainedConfig): layer_norm_eps: The epsilon used by LayerNorm. 
""" pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP + model_type = "t5" def __init__( self, diff --git a/src/transformers/configuration_transfo_xl.py b/src/transformers/configuration_transfo_xl.py index 7b285ca3ed..9e332fa8c3 100644 --- a/src/transformers/configuration_transfo_xl.py +++ b/src/transformers/configuration_transfo_xl.py @@ -65,6 +65,7 @@ class TransfoXLConfig(PretrainedConfig): """ pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP + model_type = "transfo-xl" def __init__( self, diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index d0865ee179..d87547ea22 100644 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -46,7 +46,8 @@ class PretrainedConfig(object): ``output_hidden_states``: string, default `False`. Should the model returns all hidden-states. ``torchscript``: string, default `False`. Is the model used with Torchscript. """ - pretrained_config_archive_map = {} + pretrained_config_archive_map: Dict[str, str] = {} + model_type: str def __init__(self, **kwargs): # Attributes with defaults @@ -155,11 +156,11 @@ class PretrainedConfig(object): assert unused_kwargs == {'foo': False} """ - config_dict, kwargs = cls.resolved_config_dict(pretrained_model_name_or_path, **kwargs) + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) return cls.from_dict(config_dict, **kwargs) @classmethod - def resolved_config_dict( + def get_config_dict( cls, pretrained_model_name_or_path: str, pretrained_config_archive_map: Optional[Dict] = None, **kwargs ) -> Tuple[Dict, Dict]: """ @@ -257,7 +258,7 @@ class PretrainedConfig(object): @classmethod def from_json_file(cls, json_file: str): - """Constructs a `Config` from a json file of parameters.""" + """Constructs a `Config` from the path to a json file of parameters.""" config_dict = cls._dict_from_json_file(json_file) return cls(**config_dict) diff --git 
a/src/transformers/configuration_xlm.py b/src/transformers/configuration_xlm.py index b56182413b..c3bf64c724 100644 --- a/src/transformers/configuration_xlm.py +++ b/src/transformers/configuration_xlm.py @@ -78,6 +78,7 @@ class XLMConfig(PretrainedConfig): """ pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP + model_type = "xlm" def __init__( self, diff --git a/src/transformers/configuration_xlm_roberta.py b/src/transformers/configuration_xlm_roberta.py index a9cdf7c160..0208a0449c 100644 --- a/src/transformers/configuration_xlm_roberta.py +++ b/src/transformers/configuration_xlm_roberta.py @@ -35,3 +35,4 @@ XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { class XLMRobertaConfig(RobertaConfig): pretrained_config_archive_map = XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP + model_type = "xlm-roberta" diff --git a/src/transformers/configuration_xlnet.py b/src/transformers/configuration_xlnet.py index 38d00d7604..2e4536fa0d 100644 --- a/src/transformers/configuration_xlnet.py +++ b/src/transformers/configuration_xlnet.py @@ -69,6 +69,7 @@ class XLNetConfig(PretrainedConfig): """ pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP + model_type = "xlnet" def __init__( self, diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index ff134c58dc..5d65bef623 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -16,6 +16,8 @@ import logging +from collections import OrderedDict +from typing import Type from .configuration_auto import ( AlbertConfig, @@ -76,6 +78,7 @@ from .modeling_roberta import ( ) from .modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5Model, T5WithLMHeadModel from .modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TransfoXLLMHeadModel, TransfoXLModel +from .modeling_utils import PreTrainedModel from .modeling_xlm import ( XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLMForQuestionAnswering, @@ -123,6 +126,35 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict( 
for key, value, in pretrained_map.items() ) +MODEL_MAPPING: "OrderedDict[Type[PretrainedConfig], Type[PreTrainedModel]]" = OrderedDict( # annotation is quoted: collections.OrderedDict is not subscriptable at runtime before Python 3.9 + [ # order matters: lookups take the first isinstance() match, so a subclass config must precede its base + (T5Config, T5Model), + (DistilBertConfig, DistilBertModel), + (AlbertConfig, AlbertModel), + (CamembertConfig, CamembertModel), + (XLMRobertaConfig, XLMRobertaModel), + (RobertaConfig, RobertaModel), + (BertConfig, BertModel), + (OpenAIGPTConfig, OpenAIGPTModel), + (GPT2Config, GPT2Model), + (TransfoXLConfig, TransfoXLModel), + (XLNetConfig, XLNetModel), + (XLMConfig, XLMModel), + (CTRLConfig, CTRLModel), + ] +) + +MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING: "OrderedDict[Type[PretrainedConfig], Type[PreTrainedModel]]" = OrderedDict( # annotation is quoted: collections.OrderedDict is not subscriptable at runtime before Python 3.9 + [ # order matters: lookups take the first isinstance() match, so a subclass config must precede its base + (DistilBertConfig, DistilBertForTokenClassification), + (CamembertConfig, CamembertForTokenClassification), + (XLMRobertaConfig, XLMRobertaForTokenClassification), + (RobertaConfig, RobertaForTokenClassification), + (BertConfig, BertForTokenClassification), + (XLNetConfig, XLNetForTokenClassification), + ] +) + class AutoModel(object): r""" @@ -183,30 +215,9 @@ class AutoModel(object): config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. model = AutoModel.from_config(config) # E.g. 
model was saved using `save_pretrained('./test/saved_model/')` """ - if isinstance(config, DistilBertConfig): - return DistilBertModel(config) - elif isinstance(config, RobertaConfig): - return RobertaModel(config) - elif isinstance(config, BertConfig): - return BertModel(config) - elif isinstance(config, OpenAIGPTConfig): - return OpenAIGPTModel(config) - elif isinstance(config, GPT2Config): - return GPT2Model(config) - elif isinstance(config, TransfoXLConfig): - return TransfoXLModel(config) - elif isinstance(config, XLNetConfig): - return XLNetModel(config) - elif isinstance(config, XLMConfig): - return XLMModel(config) - elif isinstance(config, CTRLConfig): - return CTRLModel(config) - elif isinstance(config, AlbertConfig): - return AlbertModel(config) - elif isinstance(config, CamembertConfig): - return CamembertModel(config) - elif isinstance(config, XLMRobertaConfig): - return XLMRobertaModel(config) + for config_class, model_class in MODEL_MAPPING.items(): + if isinstance(config, config_class): + return model_class(config) raise ValueError("Unrecognized configuration class {}".format(config)) @classmethod @@ -294,32 +305,9 @@ class AutoModel(object): if not isinstance(config, PretrainedConfig): config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - if isinstance(config, T5Config): - return T5Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) - elif isinstance(config, DistilBertConfig): - return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) - elif isinstance(config, AlbertConfig): - return AlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) - elif isinstance(config, CamembertConfig): - return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) - elif isinstance(config, XLMRobertaConfig): - return 
XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) - elif isinstance(config, RobertaConfig): - return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) - elif isinstance(config, BertConfig): - return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) - elif isinstance(config, OpenAIGPTConfig): - return OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) - elif isinstance(config, GPT2Config): - return GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) - elif isinstance(config, TransfoXLConfig): - return TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) - elif isinstance(config, XLNetConfig): - return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) - elif isinstance(config, XLMConfig): - return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) - elif isinstance(config, CTRLConfig): - return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + for config_class, model_class in MODEL_MAPPING.items(): + if isinstance(config, config_class): + return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) raise ValueError( "Unrecognized model identifier in {}. Should contains one of " "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " diff --git a/templates/adding_a_new_model/configuration_xxx.py b/templates/adding_a_new_model/configuration_xxx.py index 647db8207b..3c3c28a8c0 100644 --- a/templates/adding_a_new_model/configuration_xxx.py +++ b/templates/adding_a_new_model/configuration_xxx.py @@ -58,6 +58,7 @@ class XxxConfig(PretrainedConfig): layer_norm_eps: The epsilon used by LayerNorm. 
""" pretrained_config_archive_map = XXX_PRETRAINED_CONFIG_ARCHIVE_MAP + model_type = "xxx" def __init__( self,