From 86578bb04c9b34f9d8e35cd4fad42a85910dd9e9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 12 Jun 2020 10:01:49 +0200 Subject: [PATCH] [AutoModel] Split AutoModelWithLMHead into clm, mlm, encoder-decoder (#4933) * first commit * add new auto models * better naming * fix bert automodel * fix automodel for pretraining * add models to init * fix name typo * fix typo * better naming * future warning instead of depreciation warning --- src/transformers/__init__.py | 7 + src/transformers/modeling_auto.py | 458 ++++++++++++++++++- src/transformers/modeling_bert.py | 3 + src/transformers/modeling_encoder_decoder.py | 10 +- tests/test_modeling_auto.py | 54 +++ tests/test_modeling_bert.py | 3 +- 6 files changed, 528 insertions(+), 7 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 674ab5850d..67fff0ee55 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -166,11 +166,17 @@ if is_torch_available(): AutoModelForSequenceClassification, AutoModelForQuestionAnswering, AutoModelWithLMHead, + AutoModelForCausalLM, + AutoModelForMaskedLM, + AutoModelForSeq2SeqLM, AutoModelForTokenClassification, AutoModelForMultipleChoice, MODEL_MAPPING, MODEL_FOR_PRETRAINING_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, + MODEL_FOR_CAUSAL_LM_MAPPING, + MODEL_FOR_MASKED_LM_MAPPING, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, MODEL_FOR_QUESTION_ANSWERING_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, @@ -182,6 +188,7 @@ if is_torch_available(): BertModel, BertForPreTraining, BertForMaskedLM, + BertLMHeadModel, BertForNextSentencePrediction, BertForSequenceClassification, BertForMultipleChoice, diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index 31236e1e47..4dcfcbf27b 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -16,6 +16,7 @@ import logging +import warnings from collections import OrderedDict from .configuration_auto import ( @@ -58,6 +59,7 @@ from .modeling_bert import ( BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification, + BertLMHeadModel, BertModel, ) from .modeling_camembert import ( @@ -210,6 +212,46 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( ] ) +MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict( + [ + (BertConfig, BertLMHeadModel), + (OpenAIGPTConfig, OpenAIGPTLMHeadModel), + (GPT2Config, GPT2LMHeadModel), + (TransfoXLConfig, TransfoXLLMHeadModel), + (XLNetConfig, XLNetLMHeadModel), + ( + XLMConfig, + XLMWithLMHeadModel, + ), # XLM can be MLM and CLM => model should be split similar to BERT; leave here for now + (CTRLConfig, CTRLLMHeadModel), + (ReformerConfig, ReformerModelWithLMHead), + ] +) + +MODEL_FOR_MASKED_LM_MAPPING = OrderedDict( + [ + (DistilBertConfig, DistilBertForMaskedLM), + (AlbertConfig, AlbertForMaskedLM), + (CamembertConfig, CamembertForMaskedLM), + (XLMRobertaConfig, XLMRobertaForMaskedLM), + (LongformerConfig, LongformerForMaskedLM), + (RobertaConfig, RobertaForMaskedLM), + (BertConfig, BertForMaskedLM), + (FlaubertConfig, FlaubertWithLMHeadModel), + (XLMConfig, XLMWithLMHeadModel), + (ElectraConfig, ElectraForMaskedLM), + ] +) + +MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict( + [ + (T5Config, T5ForConditionalGeneration), + (MarianConfig, MarianMTModel), + (BartConfig, BartForConditionalGeneration), + (EncoderDecoderConfig, EncoderDecoderModel), + ] +) + MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( [ (DistilBertConfig, DistilBertForSequenceClassification), @@ -620,6 +662,10 @@ class AutoModelWithLMHead: config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. model = AutoModelWithLMHead.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')` """ + warnings.warn( + "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.", + FutureWarning, + ) for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items(): if isinstance(config, config_class): return model_class(config) @@ -638,7 +684,7 @@ class AutoModelWithLMHead: The `from_pretrained()` method takes care of returning the correct model class instance based on the `model_type` property of the config object, or when it's missing, falling back to using pattern matching on the `pretrained_model_name_or_path` string: - - `t5`: :class:`~transformers.T5ModelWithLMHead` (T5 model) + - `t5`: :class:`~transformers.T5ForConditionalGeneration` (T5 model) - `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model) - `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model) - `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model) @@ -704,6 +750,10 @@ class AutoModelWithLMHead: model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) """ + warnings.warn( + "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.", + FutureWarning, + ) config = kwargs.pop("config", None) if not isinstance(config, PretrainedConfig): config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) @@ -719,6 +769,412 @@ class AutoModelWithLMHead: ) +class AutoModelForCausalLM: + r""" + :class:`~transformers.AutoModelForCausalLM` is a generic model class + that will be instantiated as one of the language modeling model classes of the library + when created with the `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)` + class method. + + This class cannot be instantiated using `__init__()` (throws an error). + """ + + def __init__(self): + raise EnvironmentError( + "AutoModelForCausalLM is designed to be instantiated " + "using the `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)` or " + "`AutoModelForCausalLM.from_config(config)` methods." + ) + + @classmethod + def from_config(cls, config): + r""" Instantiates one of the base model classes of the library + from a configuration. + + Note: + Loading a model from its configuration file does **not** load the model weights. + It only affects the model's configuration. Use :func:`~transformers.AutoModel.from_pretrained` to load + the model weights + + Args: + config (:class:`~transformers.PretrainedConfig`): + The model class to instantiate is selected based on the configuration class: + + - isInstance of `bert` configuration class: :class:`~transformers.BertLMHeadModel` (Bert model) + - isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model) + - isInstance of `gpt2` configuration class: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model) + - isInstance of `ctrl` configuration class: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model) + - isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model) + - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetLMHeadModel` (XLNet model) + - isInstance of `reformer` configuration class: :class:`~transformers.ReformerModelWithLMHead` (Reformer model) + + Examples:: + + config = GPT2Config.from_pretrained('gpt2') # Download configuration from S3 and cache. + model = AutoModelForCausalLM.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')` + """ + for config_class, model_class in MODEL_FOR_CAUSAL_LM_MAPPING.items(): + if isinstance(config, config_class): + return model_class(config) + raise ValueError( + "Unrecognized configuration class {} for this kind of AutoModel: {}.\n" + "Model type should be one of {}.".format( + config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_CAUSAL_LM_MAPPING.keys()) + ) + ) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" Instantiates one of the language modeling model classes of the library + from a pre-trained model configuration. + + The `from_pretrained()` method takes care of returning the correct model class instance + based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string: + - `bert`: :class:`~transformers.BertLMHeadModel` (Bert model) + - `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model) + - `gpt2`: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model) + - `transfo-xl`: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model) + - `xlnet`: :class:`~transformers.XLNetLMHeadModel` (XLNet model) + - `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model) + - `reformer`: :class:`~transformers.ReformerModelWithLMHead` (Google Reformer model) + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) + To train the model, you should first set it back in training mode with `model.train()` + + Args: + pretrained_model_name_or_path: + Either: + + - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. + - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. + - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + model_args: (`optional`) Sequence of positional arguments: + All remaning positional arguments will be passed to the underlying model's ``__init__`` method + config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`: + Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when: + + - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or + - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory. + - the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory. + + state_dict: (`optional`) dict: + an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file. + This option can be used if you want to create a model from a pretrained configuration but load your own weights. + In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option. + cache_dir: (`optional`) string: + Path to a directory in which a downloaded pre-trained model + configuration should be cached if the standard cache should not be used. + force_download: (`optional`) boolean, default False: + Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + resume_download: (`optional`) boolean, default False: + Do not delete incompletely received file. Attempt to resume the download if such a file exists. + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + output_loading_info: (`optional`) boolean: + Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages. + kwargs: (`optional`) Remaining dictionary of keyword arguments: + These arguments will be passed to the configuration and the model. + + Examples:: + + model = AutoModelForCausalLM.from_pretrained('gpt2') # Download model and configuration from S3 and cache. + model = AutoModelForCausalLM.from_pretrained('./test/gpt2_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')` + assert model.config.output_attention == True + # Loading from a TF checkpoint file instead of a PyTorch model (slower) + config = AutoConfig.from_json_file('./tf_model/gpt2_tf_model_config.json') + model = AutoModelForCausalLM.from_pretrained('./tf_model/gpt2_tf_checkpoint.ckpt.index', from_tf=True, config=config) + + """ + config = kwargs.pop("config", None) + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + for config_class, model_class in MODEL_FOR_CAUSAL_LM_MAPPING.items(): + if isinstance(config, config_class): + return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + raise ValueError( + "Unrecognized configuration class {} for this kind of AutoModel: {}.\n" + "Model type should be one of {}.".format( + config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_CAUSAL_LM_MAPPING.keys()) + ) + ) + + +class AutoModelForMaskedLM: + r""" + :class:`~transformers.AutoModelForMaskedLM` is a generic model class + that will be instantiated as one of the language modeling model classes of the library + when created with the `AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)` + class method. + + This class cannot be instantiated using `__init__()` (throws an error). + """ + + def __init__(self): + raise EnvironmentError( + "AutoModelForMaskedLM is designed to be instantiated " + "using the `AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)` or " + "`AutoModelForMaskedLM.from_config(config)` methods." + ) + + @classmethod + def from_config(cls, config): + r""" Instantiates one of the base model classes of the library + from a configuration. + + Note: + Loading a model from its configuration file does **not** load the model weights. + It only affects the model's configuration. Use :func:`~transformers.AutoModel.from_pretrained` to load + the model weights + + Args: + config (:class:`~transformers.PretrainedConfig`): + The model class to instantiate is selected based on the configuration class: + - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model) + - isInstance of `longformer` configuration class: :class:`~transformers.LongformerForMaskedLM` (Longformer model) + - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model) + - isInstance of `bert` configuration class: :class:`~transformers.BertForMaskedLM` (Bert model) + - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model) + - isInstance of `xlm` configuration class: :class:`~transformers.XLMWithLMHeadModel` (XLM model) + - isInstance of `xlm-roberta` configuration class: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-Roberta model) + - isInstance of `electra` configuration class: :class:`~transformers.ElectraForMaskedLM` (Electra model) + - isInstance of `camembert` configuration class: :class:`~transformers.CamembertForMaskedLM` (Camembert model) + - isInstance of `albert` configuration class: :class:`~transformers.AlbertForMaskedLM` (Albert model) + + + Examples:: + + config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. + model = AutoModelForMaskedLM.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')` + """ + for config_class, model_class in MODEL_FOR_MASKED_LM_MAPPING.items(): + if isinstance(config, config_class): + return model_class(config) + raise ValueError( + "Unrecognized configuration class {} for this kind of AutoModel: {}.\n" + "Model type should be one of {}.".format( + config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_MASKED_LM_MAPPING.keys()) + ) + ) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" Instantiates one of the language modeling model classes of the library + from a pre-trained model configuration. + + The `from_pretrained()` method takes care of returning the correct model class instance + based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string: + - `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model) + - `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model) + - `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model) + - `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model) + - `longformer`: :class:`~transformers.LongformerForMaskedLM` (Longformer model) + - `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model) + - `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model) + - `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model) + - `electra`: :class:`~transformers.ElectraForMaskedLM` (Electra model) + - `bert`: :class:`~transformers.BertLMHeadModel` (Bert model) + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) + To train the model, you should first set it back in training mode with `model.train()` + + Args: + pretrained_model_name_or_path: + Either: + + - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. + - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. + - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + model_args: (`optional`) Sequence of positional arguments: + All remaning positional arguments will be passed to the underlying model's ``__init__`` method + config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`: + Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when: + + - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or + - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory. + - the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory. + + state_dict: (`optional`) dict: + an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file. + This option can be used if you want to create a model from a pretrained configuration but load your own weights. + In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option. + cache_dir: (`optional`) string: + Path to a directory in which a downloaded pre-trained model + configuration should be cached if the standard cache should not be used. + force_download: (`optional`) boolean, default False: + Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + resume_download: (`optional`) boolean, default False: + Do not delete incompletely received file. Attempt to resume the download if such a file exists. + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + output_loading_info: (`optional`) boolean: + Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages. + kwargs: (`optional`) Remaining dictionary of keyword arguments: + These arguments will be passed to the configuration and the model. + + Examples:: + + model = AutoModelForMaskedLM.from_pretrained('bert') # Download model and configuration from S3 and cache. + model = AutoModelForMaskedLM.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')` + assert model.config.output_attention == True + # Loading from a TF checkpoint file instead of a PyTorch model (slower) + config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json') + model = AutoModelForMaskedLM.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) + + """ + config = kwargs.pop("config", None) + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + for config_class, model_class in MODEL_FOR_MASKED_LM_MAPPING.items(): + if isinstance(config, config_class): + return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + raise ValueError( + "Unrecognized configuration class {} for this kind of AutoModel: {}.\n" + "Model type should be one of {}.".format( + config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_MASKED_LM_MAPPING.keys()) + ) + ) + + +class AutoModelForSeq2SeqLM: + r""" + :class:`~transformers.AutoModelForSeq2SeqLM` is a generic model class + that will be instantiated as one of the language modeling model classes of the library + when created with the `AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path)` + class method. + + This class cannot be instantiated using `__init__()` (throws an error). + """ + + def __init__(self): + raise EnvironmentError( + "AutoModelForSeq2SeqLM is designed to be instantiated " + "using the `AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path)` or " + "`AutoModelForSeq2SeqLM.from_config(config)` methods." + ) + + @classmethod + def from_config(cls, config): + r""" Instantiates one of the base model classes of the library + from a configuration. + + Note: + Loading a model from its configuration file does **not** load the model weights. + It only affects the model's configuration. Use :func:`~transformers.AutoModel.from_pretrained` to load + the model weights + + Args: + config (:class:`~transformers.PretrainedConfig`): + The model class to instantiate is selected based on the configuration class: + + - isInstance of `t5` configuration class: :class:`~transformers.T5ForConditionalGeneration` (T5 model) + - isInstance of `bart` configuration class: :class:`~transformers.BartForConditionalGeneration` (Bart model) + - isInstance of `marian` configuration class: :class:`~transformers.MarianMTModel` (Marian model) + - isInstance of `encoder-decoder` configuration class: :class:`~transformers.EncoderDecoderModel` (Encoder Decoder model) + + Examples:: + + config = T5Config.from_pretrained('t5') + model = AutoModelForSeq2SeqLM.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')` + """ + for config_class, model_class in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items(): + if isinstance(config, config_class): + return model_class(config) + raise ValueError( + "Unrecognized configuration class {} for this kind of AutoModel: {}.\n" + "Model type should be one of {}.".format( + config.__class__, + cls.__name__, + ", ".join(c.__name__ for c in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys()), + ) + ) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" Instantiates one of the language modeling model classes of the library + from a pre-trained model configuration. + + The `from_pretrained()` method takes care of returning the correct model class instance + based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string: + - `t5`: :class:`~transformers.T5ForConditionalGeneration` (T5 model) + - `bart`: :class:`~transformers.BartForConditionalGeneration` (Bert model) + - `marian`: :class:`~transformers.MarianMTModel` (Marian model) + - `encoder-decoder`: :class:`~transformers.EncoderDecoderModel` (Encoder Decoder model) + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) + To train the model, you should first set it back in training mode with `model.train()` + + Args: + pretrained_model_name_or_path: + Either: + + - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. + - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. + - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + model_args: (`optional`) Sequence of positional arguments: + All remaning positional arguments will be passed to the underlying model's ``__init__`` method + config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`: + Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when: + + - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or + - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory. + - the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory. + + state_dict: (`optional`) dict: + an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file. + This option can be used if you want to create a model from a pretrained configuration but load your own weights. + In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option. + cache_dir: (`optional`) string: + Path to a directory in which a downloaded pre-trained model + configuration should be cached if the standard cache should not be used. + force_download: (`optional`) boolean, default False: + Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + resume_download: (`optional`) boolean, default False: + Do not delete incompletely received file. Attempt to resume the download if such a file exists. + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + output_loading_info: (`optional`) boolean: + Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages. + kwargs: (`optional`) Remaining dictionary of keyword arguments: + These arguments will be passed to the configuration and the model. + + Examples:: + + model = AutoModelForSeq2SeqLM.from_pretrained('t5-base') # Download model and configuration from S3 and cache. + model = AutoModelForSeq2SeqLM.from_pretrained('./test/t5_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')` + assert model.config.output_attention == True + # Loading from a TF checkpoint file instead of a PyTorch model (slower) + config = AutoConfig.from_json_file('./tf_model/t5_tf_model_config.json') + model = AutoModelForSeq2SeqLM.from_pretrained('./tf_model/t5_tf_checkpoint.ckpt.index', from_tf=True, config=config) + + """ + config = kwargs.pop("config", None) + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + for config_class, model_class in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items(): + if isinstance(config, config_class): + return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + raise ValueError( + "Unrecognized configuration class {} for this kind of AutoModel: {}.\n" + "Model type should be one of {}.".format( + config.__class__, + cls.__name__, + ", ".join(c.__name__ for c in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys()), + ) + ) + + class AutoModelForSequenceClassification: r""" :class:`~transformers.AutoModelForSequenceClassification` is a generic model class diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index 666495f87b..25dd83c0f7 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -987,6 +987,9 @@ class BertLMHeadModel(BertPreTrainedModel): class BertForMaskedLM(BertPreTrainedModel): def __init__(self, config): super().__init__(config) + assert ( + not config.is_decoder + ), "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention." self.bert = BertModel(config) self.cls = BertOnlyMLMHead(config) diff --git a/src/transformers/modeling_encoder_decoder.py b/src/transformers/modeling_encoder_decoder.py index cb44671a9f..a458327e00 100644 --- a/src/transformers/modeling_encoder_decoder.py +++ b/src/transformers/modeling_encoder_decoder.py @@ -32,7 +32,7 @@ class EncoderDecoderModel(PreTrainedModel): instantiated as a transformer architecture with one of the base model classes of the library as encoder and another one as decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)` - class method for the encoder and `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` class method for the decoder. + class method for the encoder and `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)` class method for the decoder. """ config_class = EncoderDecoderConfig base_model_prefix = "encoder_decoder" @@ -61,9 +61,9 @@ class EncoderDecoderModel(PreTrainedModel): encoder = AutoModel.from_config(config.encoder) if decoder is None: - from transformers import AutoModelWithLMHead + from transformers import AutoModelForCausalLM - decoder = AutoModelWithLMHead.from_config(config.decoder) + decoder = AutoModelForCausalLM.from_config(config.decoder) self.encoder = encoder self.decoder = decoder @@ -157,7 +157,7 @@ class EncoderDecoderModel(PreTrainedModel): assert ( decoder_pretrained_model_name_or_path is not None ), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined" - from .modeling_auto import AutoModelWithLMHead + from .modeling_auto import AutoModelForCausalLM if "config" not in kwargs_decoder: from transformers import AutoConfig @@ -176,7 +176,7 @@ class EncoderDecoderModel(PreTrainedModel): f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attribute `is_decoder` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` is set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`" ) - decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) + decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) return cls(encoder=encoder, decoder=decoder) diff --git a/tests/test_modeling_auto.py b/tests/test_modeling_auto.py index b933d6e5d3..21a8aa4e81 100644 --- a/tests/test_modeling_auto.py +++ b/tests/test_modeling_auto.py @@ -26,13 +26,20 @@ if is_torch_available(): from transformers import ( AutoConfig, BertConfig, + GPT2Config, + T5Config, AutoModel, BertModel, AutoModelForPreTraining, BertForPreTraining, + AutoModelForCausalLM, + GPT2LMHeadModel, AutoModelWithLMHead, + AutoModelForMaskedLM, BertForMaskedLM, RobertaForMaskedLM, + AutoModelForSeq2SeqLM, + T5ForConditionalGeneration, AutoModelForSequenceClassification, BertForSequenceClassification, AutoModelForQuestionAnswering, @@ -41,6 +48,8 @@ if is_torch_available(): BertForTokenClassification, ) from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST + from transformers.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_LIST + from transformers.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_LIST from transformers.modeling_auto import ( MODEL_MAPPING, MODEL_FOR_PRETRAINING_MAPPING, @@ -48,6 +57,9 @@ if is_torch_available(): MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, + MODEL_FOR_CAUSAL_LM_MAPPING, + MODEL_FOR_MASKED_LM_MAPPING, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, ) @@ -97,6 +109,45 @@ class AutoModelTest(unittest.TestCase): self.assertIsNotNone(model) self.assertIsInstance(model, BertForMaskedLM) + @slow + def test_model_for_causal_lm(self): + logging.basicConfig(level=logging.INFO) + for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + config = AutoConfig.from_pretrained(model_name) + self.assertIsNotNone(config) + self.assertIsInstance(config, GPT2Config) + + model = AutoModelForCausalLM.from_pretrained(model_name) + model, loading_info = AutoModelForCausalLM.from_pretrained(model_name, output_loading_info=True) + self.assertIsNotNone(model) + self.assertIsInstance(model, GPT2LMHeadModel) + + @slow + def test_model_for_masked_lm(self): + logging.basicConfig(level=logging.INFO) + for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + config = AutoConfig.from_pretrained(model_name) + self.assertIsNotNone(config) + self.assertIsInstance(config, BertConfig) + + model = AutoModelForMaskedLM.from_pretrained(model_name) + model, loading_info = AutoModelForMaskedLM.from_pretrained(model_name, output_loading_info=True) + self.assertIsNotNone(model) + self.assertIsInstance(model, BertForMaskedLM) + + @slow + def test_model_for_encoder_decoder_lm(self): + logging.basicConfig(level=logging.INFO) + for model_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + config = AutoConfig.from_pretrained(model_name) + self.assertIsNotNone(config) + self.assertIsInstance(config, T5Config) + + model = AutoModelForSeq2SeqLM.from_pretrained(model_name) + model, loading_info = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_loading_info=True) + self.assertIsNotNone(model) + self.assertIsInstance(model, T5ForConditionalGeneration) + @slow def test_sequence_classification_model_from_pretrained(self): logging.basicConfig(level=logging.INFO) @@ -163,6 +214,9 @@ class AutoModelTest(unittest.TestCase): MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, + MODEL_FOR_CAUSAL_LM_MAPPING, + MODEL_FOR_MASKED_LM_MAPPING, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, ) for mapping in mappings: diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py index 737a0082c9..5b2397b0ae 100644 --- a/tests/test_modeling_bert.py +++ b/tests/test_modeling_bert.py @@ -27,6 +27,7 @@ if is_torch_available(): from transformers import ( BertConfig, BertModel, + BertLMHeadModel, BertForMaskedLM, BertForNextSentencePrediction, BertForPreTraining, @@ -35,7 +36,7 @@ if is_torch_available(): BertForTokenClassification, BertForMultipleChoice, ) - from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST, BertLMHeadModel + from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST class BertModelTester: