[AutoModel] Split AutoModelWithLMHead into clm, mlm, encoder-decoder (#4933)
* first commit * add new auto models * better naming * fix bert automodel * fix automodel for pretraining * add models to init * fix name typo * fix typo * better naming * future warning instead of depreciation warning
This commit is contained in:
committed by
GitHub
parent
5620033115
commit
86578bb04c
@@ -166,11 +166,17 @@ if is_torch_available():
|
|||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
AutoModelForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
AutoModelWithLMHead,
|
AutoModelWithLMHead,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoModelForMaskedLM,
|
||||||
|
AutoModelForSeq2SeqLM,
|
||||||
AutoModelForTokenClassification,
|
AutoModelForTokenClassification,
|
||||||
AutoModelForMultipleChoice,
|
AutoModelForMultipleChoice,
|
||||||
MODEL_MAPPING,
|
MODEL_MAPPING,
|
||||||
MODEL_FOR_PRETRAINING_MAPPING,
|
MODEL_FOR_PRETRAINING_MAPPING,
|
||||||
MODEL_WITH_LM_HEAD_MAPPING,
|
MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
|
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
|
MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
@@ -182,6 +188,7 @@ if is_torch_available():
|
|||||||
BertModel,
|
BertModel,
|
||||||
BertForPreTraining,
|
BertForPreTraining,
|
||||||
BertForMaskedLM,
|
BertForMaskedLM,
|
||||||
|
BertLMHeadModel,
|
||||||
BertForNextSentencePrediction,
|
BertForNextSentencePrediction,
|
||||||
BertForSequenceClassification,
|
BertForSequenceClassification,
|
||||||
BertForMultipleChoice,
|
BertForMultipleChoice,
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import warnings
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
from .configuration_auto import (
|
from .configuration_auto import (
|
||||||
@@ -58,6 +59,7 @@ from .modeling_bert import (
|
|||||||
BertForQuestionAnswering,
|
BertForQuestionAnswering,
|
||||||
BertForSequenceClassification,
|
BertForSequenceClassification,
|
||||||
BertForTokenClassification,
|
BertForTokenClassification,
|
||||||
|
BertLMHeadModel,
|
||||||
BertModel,
|
BertModel,
|
||||||
)
|
)
|
||||||
from .modeling_camembert import (
|
from .modeling_camembert import (
|
||||||
@@ -210,6 +212,46 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
|
||||||
|
[
|
||||||
|
(BertConfig, BertLMHeadModel),
|
||||||
|
(OpenAIGPTConfig, OpenAIGPTLMHeadModel),
|
||||||
|
(GPT2Config, GPT2LMHeadModel),
|
||||||
|
(TransfoXLConfig, TransfoXLLMHeadModel),
|
||||||
|
(XLNetConfig, XLNetLMHeadModel),
|
||||||
|
(
|
||||||
|
XLMConfig,
|
||||||
|
XLMWithLMHeadModel,
|
||||||
|
), # XLM can be MLM and CLM => model should be split similar to BERT; leave here for now
|
||||||
|
(CTRLConfig, CTRLLMHeadModel),
|
||||||
|
(ReformerConfig, ReformerModelWithLMHead),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
||||||
|
[
|
||||||
|
(DistilBertConfig, DistilBertForMaskedLM),
|
||||||
|
(AlbertConfig, AlbertForMaskedLM),
|
||||||
|
(CamembertConfig, CamembertForMaskedLM),
|
||||||
|
(XLMRobertaConfig, XLMRobertaForMaskedLM),
|
||||||
|
(LongformerConfig, LongformerForMaskedLM),
|
||||||
|
(RobertaConfig, RobertaForMaskedLM),
|
||||||
|
(BertConfig, BertForMaskedLM),
|
||||||
|
(FlaubertConfig, FlaubertWithLMHeadModel),
|
||||||
|
(XLMConfig, XLMWithLMHeadModel),
|
||||||
|
(ElectraConfig, ElectraForMaskedLM),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
|
||||||
|
[
|
||||||
|
(T5Config, T5ForConditionalGeneration),
|
||||||
|
(MarianConfig, MarianMTModel),
|
||||||
|
(BartConfig, BartForConditionalGeneration),
|
||||||
|
(EncoderDecoderConfig, EncoderDecoderModel),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||||
[
|
[
|
||||||
(DistilBertConfig, DistilBertForSequenceClassification),
|
(DistilBertConfig, DistilBertForSequenceClassification),
|
||||||
@@ -620,6 +662,10 @@ class AutoModelWithLMHead:
|
|||||||
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
||||||
model = AutoModelWithLMHead.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
model = AutoModelWithLMHead.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
"""
|
"""
|
||||||
|
warnings.warn(
|
||||||
|
"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
|
for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
|
||||||
if isinstance(config, config_class):
|
if isinstance(config, config_class):
|
||||||
return model_class(config)
|
return model_class(config)
|
||||||
@@ -638,7 +684,7 @@ class AutoModelWithLMHead:
|
|||||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||||
based on the `model_type` property of the config object, or when it's missing,
|
based on the `model_type` property of the config object, or when it's missing,
|
||||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string:
|
falling back to using pattern matching on the `pretrained_model_name_or_path` string:
|
||||||
- `t5`: :class:`~transformers.T5ModelWithLMHead` (T5 model)
|
- `t5`: :class:`~transformers.T5ForConditionalGeneration` (T5 model)
|
||||||
- `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
|
- `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
|
||||||
- `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
|
- `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
|
||||||
- `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
|
- `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
|
||||||
@@ -704,6 +750,10 @@ class AutoModelWithLMHead:
|
|||||||
model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
warnings.warn(
|
||||||
|
"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
config = kwargs.pop("config", None)
|
config = kwargs.pop("config", None)
|
||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
@@ -719,6 +769,412 @@ class AutoModelWithLMHead:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForCausalLM:
|
||||||
|
r"""
|
||||||
|
:class:`~transformers.AutoModelForCausalLM` is a generic model class
|
||||||
|
that will be instantiated as one of the language modeling model classes of the library
|
||||||
|
when created with the `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)`
|
||||||
|
class method.
|
||||||
|
|
||||||
|
This class cannot be instantiated using `__init__()` (throws an error).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
raise EnvironmentError(
|
||||||
|
"AutoModelForCausalLM is designed to be instantiated "
|
||||||
|
"using the `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)` or "
|
||||||
|
"`AutoModelForCausalLM.from_config(config)` methods."
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config):
|
||||||
|
r""" Instantiates one of the base model classes of the library
|
||||||
|
from a configuration.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Loading a model from its configuration file does **not** load the model weights.
|
||||||
|
It only affects the model's configuration. Use :func:`~transformers.AutoModel.from_pretrained` to load
|
||||||
|
the model weights
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (:class:`~transformers.PretrainedConfig`):
|
||||||
|
The model class to instantiate is selected based on the configuration class:
|
||||||
|
|
||||||
|
- isInstance of `bert` configuration class: :class:`~transformers.BertLMHeadModel` (Bert model)
|
||||||
|
- isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
|
||||||
|
- isInstance of `gpt2` configuration class: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
|
||||||
|
- isInstance of `ctrl` configuration class: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
|
||||||
|
- isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
|
||||||
|
- isInstance of `xlnet` configuration class: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
|
||||||
|
- isInstance of `reformer` configuration class: :class:`~transformers.ReformerModelWithLMHead` (Reformer model)
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
config = GPT2Config.from_pretrained('gpt2') # Download configuration from S3 and cache.
|
||||||
|
model = AutoModelForCausalLM.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
|
"""
|
||||||
|
for config_class, model_class in MODEL_FOR_CAUSAL_LM_MAPPING.items():
|
||||||
|
if isinstance(config, config_class):
|
||||||
|
return model_class(config)
|
||||||
|
raise ValueError(
|
||||||
|
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||||
|
"Model type should be one of {}.".format(
|
||||||
|
config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_CAUSAL_LM_MAPPING.keys())
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
|
r""" Instantiates one of the language modeling model classes of the library
|
||||||
|
from a pre-trained model configuration.
|
||||||
|
|
||||||
|
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||||
|
based on the `model_type` property of the config object, or when it's missing,
|
||||||
|
falling back to using pattern matching on the `pretrained_model_name_or_path` string:
|
||||||
|
- `bert`: :class:`~transformers.BertLMHeadModel` (Bert model)
|
||||||
|
- `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
|
||||||
|
- `gpt2`: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
|
||||||
|
- `transfo-xl`: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
|
||||||
|
- `xlnet`: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
|
||||||
|
- `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
|
||||||
|
- `reformer`: :class:`~transformers.ReformerModelWithLMHead` (Google Reformer model)
|
||||||
|
|
||||||
|
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||||
|
To train the model, you should first set it back in training mode with `model.train()`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained_model_name_or_path:
|
||||||
|
Either:
|
||||||
|
|
||||||
|
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
|
||||||
|
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
|
||||||
|
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
|
||||||
|
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||||
|
model_args: (`optional`) Sequence of positional arguments:
|
||||||
|
All remaning positional arguments will be passed to the underlying model's ``__init__`` method
|
||||||
|
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
|
||||||
|
Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:
|
||||||
|
|
||||||
|
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
|
||||||
|
- the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
|
||||||
|
- the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
|
||||||
|
|
||||||
|
state_dict: (`optional`) dict:
|
||||||
|
an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
|
||||||
|
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
|
||||||
|
In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
|
||||||
|
cache_dir: (`optional`) string:
|
||||||
|
Path to a directory in which a downloaded pre-trained model
|
||||||
|
configuration should be cached if the standard cache should not be used.
|
||||||
|
force_download: (`optional`) boolean, default False:
|
||||||
|
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||||
|
resume_download: (`optional`) boolean, default False:
|
||||||
|
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
|
||||||
|
proxies: (`optional`) dict, default None:
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||||
|
The proxies are used on each request.
|
||||||
|
output_loading_info: (`optional`) boolean:
|
||||||
|
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||||
|
kwargs: (`optional`) Remaining dictionary of keyword arguments:
|
||||||
|
These arguments will be passed to the configuration and the model.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained('gpt2') # Download model and configuration from S3 and cache.
|
||||||
|
model = AutoModelForCausalLM.from_pretrained('./test/gpt2_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
|
assert model.config.output_attention == True
|
||||||
|
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||||
|
config = AutoConfig.from_json_file('./tf_model/gpt2_tf_model_config.json')
|
||||||
|
model = AutoModelForCausalLM.from_pretrained('./tf_model/gpt2_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||||
|
|
||||||
|
"""
|
||||||
|
config = kwargs.pop("config", None)
|
||||||
|
if not isinstance(config, PretrainedConfig):
|
||||||
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
|
for config_class, model_class in MODEL_FOR_CAUSAL_LM_MAPPING.items():
|
||||||
|
if isinstance(config, config_class):
|
||||||
|
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||||
|
raise ValueError(
|
||||||
|
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||||
|
"Model type should be one of {}.".format(
|
||||||
|
config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_CAUSAL_LM_MAPPING.keys())
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForMaskedLM:
|
||||||
|
r"""
|
||||||
|
:class:`~transformers.AutoModelForMaskedLM` is a generic model class
|
||||||
|
that will be instantiated as one of the language modeling model classes of the library
|
||||||
|
when created with the `AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)`
|
||||||
|
class method.
|
||||||
|
|
||||||
|
This class cannot be instantiated using `__init__()` (throws an error).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
raise EnvironmentError(
|
||||||
|
"AutoModelForMaskedLM is designed to be instantiated "
|
||||||
|
"using the `AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)` or "
|
||||||
|
"`AutoModelForMaskedLM.from_config(config)` methods."
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config):
|
||||||
|
r""" Instantiates one of the base model classes of the library
|
||||||
|
from a configuration.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Loading a model from its configuration file does **not** load the model weights.
|
||||||
|
It only affects the model's configuration. Use :func:`~transformers.AutoModel.from_pretrained` to load
|
||||||
|
the model weights
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (:class:`~transformers.PretrainedConfig`):
|
||||||
|
The model class to instantiate is selected based on the configuration class:
|
||||||
|
- isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
|
||||||
|
- isInstance of `longformer` configuration class: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
|
||||||
|
- isInstance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
|
||||||
|
- isInstance of `bert` configuration class: :class:`~transformers.BertForMaskedLM` (Bert model)
|
||||||
|
- isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
|
||||||
|
- isInstance of `xlm` configuration class: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
|
||||||
|
- isInstance of `xlm-roberta` configuration class: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-Roberta model)
|
||||||
|
- isInstance of `electra` configuration class: :class:`~transformers.ElectraForMaskedLM` (Electra model)
|
||||||
|
- isInstance of `camembert` configuration class: :class:`~transformers.CamembertForMaskedLM` (Camembert model)
|
||||||
|
- isInstance of `albert` configuration class: :class:`~transformers.AlbertForMaskedLM` (Albert model)
|
||||||
|
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
||||||
|
model = AutoModelForMaskedLM.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
|
"""
|
||||||
|
for config_class, model_class in MODEL_FOR_MASKED_LM_MAPPING.items():
|
||||||
|
if isinstance(config, config_class):
|
||||||
|
return model_class(config)
|
||||||
|
raise ValueError(
|
||||||
|
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||||
|
"Model type should be one of {}.".format(
|
||||||
|
config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_MASKED_LM_MAPPING.keys())
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
|
r""" Instantiates one of the language modeling model classes of the library
|
||||||
|
from a pre-trained model configuration.
|
||||||
|
|
||||||
|
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||||
|
based on the `model_type` property of the config object, or when it's missing,
|
||||||
|
falling back to using pattern matching on the `pretrained_model_name_or_path` string:
|
||||||
|
- `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
|
||||||
|
- `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
|
||||||
|
- `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
|
||||||
|
- `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
|
||||||
|
- `longformer`: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
|
||||||
|
- `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
|
||||||
|
- `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
|
||||||
|
- `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
|
||||||
|
- `electra`: :class:`~transformers.ElectraForMaskedLM` (Electra model)
|
||||||
|
- `bert`: :class:`~transformers.BertLMHeadModel` (Bert model)
|
||||||
|
|
||||||
|
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||||
|
To train the model, you should first set it back in training mode with `model.train()`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained_model_name_or_path:
|
||||||
|
Either:
|
||||||
|
|
||||||
|
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
|
||||||
|
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
|
||||||
|
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
|
||||||
|
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||||
|
model_args: (`optional`) Sequence of positional arguments:
|
||||||
|
All remaning positional arguments will be passed to the underlying model's ``__init__`` method
|
||||||
|
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
|
||||||
|
Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:
|
||||||
|
|
||||||
|
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
|
||||||
|
- the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
|
||||||
|
- the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
|
||||||
|
|
||||||
|
state_dict: (`optional`) dict:
|
||||||
|
an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
|
||||||
|
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
|
||||||
|
In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
|
||||||
|
cache_dir: (`optional`) string:
|
||||||
|
Path to a directory in which a downloaded pre-trained model
|
||||||
|
configuration should be cached if the standard cache should not be used.
|
||||||
|
force_download: (`optional`) boolean, default False:
|
||||||
|
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||||
|
resume_download: (`optional`) boolean, default False:
|
||||||
|
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
|
||||||
|
proxies: (`optional`) dict, default None:
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||||
|
The proxies are used on each request.
|
||||||
|
output_loading_info: (`optional`) boolean:
|
||||||
|
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||||
|
kwargs: (`optional`) Remaining dictionary of keyword arguments:
|
||||||
|
These arguments will be passed to the configuration and the model.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
model = AutoModelForMaskedLM.from_pretrained('bert') # Download model and configuration from S3 and cache.
|
||||||
|
model = AutoModelForMaskedLM.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
|
assert model.config.output_attention == True
|
||||||
|
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||||
|
config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
|
||||||
|
model = AutoModelForMaskedLM.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||||
|
|
||||||
|
"""
|
||||||
|
config = kwargs.pop("config", None)
|
||||||
|
if not isinstance(config, PretrainedConfig):
|
||||||
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
|
for config_class, model_class in MODEL_FOR_MASKED_LM_MAPPING.items():
|
||||||
|
if isinstance(config, config_class):
|
||||||
|
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||||
|
raise ValueError(
|
||||||
|
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||||
|
"Model type should be one of {}.".format(
|
||||||
|
config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_MASKED_LM_MAPPING.keys())
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForSeq2SeqLM:
|
||||||
|
r"""
|
||||||
|
:class:`~transformers.AutoModelForSeq2SeqLM` is a generic model class
|
||||||
|
that will be instantiated as one of the language modeling model classes of the library
|
||||||
|
when created with the `AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path)`
|
||||||
|
class method.
|
||||||
|
|
||||||
|
This class cannot be instantiated using `__init__()` (throws an error).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
raise EnvironmentError(
|
||||||
|
"AutoModelForSeq2SeqLM is designed to be instantiated "
|
||||||
|
"using the `AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path)` or "
|
||||||
|
"`AutoModelForSeq2SeqLM.from_config(config)` methods."
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config):
|
||||||
|
r""" Instantiates one of the base model classes of the library
|
||||||
|
from a configuration.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Loading a model from its configuration file does **not** load the model weights.
|
||||||
|
It only affects the model's configuration. Use :func:`~transformers.AutoModel.from_pretrained` to load
|
||||||
|
the model weights
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (:class:`~transformers.PretrainedConfig`):
|
||||||
|
The model class to instantiate is selected based on the configuration class:
|
||||||
|
|
||||||
|
- isInstance of `t5` configuration class: :class:`~transformers.T5ForConditionalGeneration` (T5 model)
|
||||||
|
- isInstance of `bart` configuration class: :class:`~transformers.BartForConditionalGeneration` (Bart model)
|
||||||
|
- isInstance of `marian` configuration class: :class:`~transformers.MarianMTModel` (Marian model)
|
||||||
|
- isInstance of `encoder-decoder` configuration class: :class:`~transformers.EncoderDecoderModel` (Encoder Decoder model)
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
config = T5Config.from_pretrained('t5')
|
||||||
|
model = AutoModelForSeq2SeqLM.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
|
"""
|
||||||
|
for config_class, model_class in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
|
||||||
|
if isinstance(config, config_class):
|
||||||
|
return model_class(config)
|
||||||
|
raise ValueError(
|
||||||
|
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||||
|
"Model type should be one of {}.".format(
|
||||||
|
config.__class__,
|
||||||
|
cls.__name__,
|
||||||
|
", ".join(c.__name__ for c in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys()),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
|
r""" Instantiates one of the language modeling model classes of the library
|
||||||
|
from a pre-trained model configuration.
|
||||||
|
|
||||||
|
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||||
|
based on the `model_type` property of the config object, or when it's missing,
|
||||||
|
falling back to using pattern matching on the `pretrained_model_name_or_path` string:
|
||||||
|
- `t5`: :class:`~transformers.T5ForConditionalGeneration` (T5 model)
|
||||||
|
- `bart`: :class:`~transformers.BartForConditionalGeneration` (Bert model)
|
||||||
|
- `marian`: :class:`~transformers.MarianMTModel` (Marian model)
|
||||||
|
- `encoder-decoder`: :class:`~transformers.EncoderDecoderModel` (Encoder Decoder model)
|
||||||
|
|
||||||
|
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||||
|
To train the model, you should first set it back in training mode with `model.train()`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained_model_name_or_path:
|
||||||
|
Either:
|
||||||
|
|
||||||
|
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
|
||||||
|
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
|
||||||
|
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
|
||||||
|
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||||
|
model_args: (`optional`) Sequence of positional arguments:
|
||||||
|
All remaning positional arguments will be passed to the underlying model's ``__init__`` method
|
||||||
|
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
|
||||||
|
Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:
|
||||||
|
|
||||||
|
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
|
||||||
|
- the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
|
||||||
|
- the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
|
||||||
|
|
||||||
|
state_dict: (`optional`) dict:
|
||||||
|
an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
|
||||||
|
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
|
||||||
|
In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
|
||||||
|
cache_dir: (`optional`) string:
|
||||||
|
Path to a directory in which a downloaded pre-trained model
|
||||||
|
configuration should be cached if the standard cache should not be used.
|
||||||
|
force_download: (`optional`) boolean, default False:
|
||||||
|
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||||
|
resume_download: (`optional`) boolean, default False:
|
||||||
|
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
|
||||||
|
proxies: (`optional`) dict, default None:
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||||
|
The proxies are used on each request.
|
||||||
|
output_loading_info: (`optional`) boolean:
|
||||||
|
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||||
|
kwargs: (`optional`) Remaining dictionary of keyword arguments:
|
||||||
|
These arguments will be passed to the configuration and the model.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained('t5-base') # Download model and configuration from S3 and cache.
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained('./test/t5_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
|
assert model.config.output_attention == True
|
||||||
|
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||||
|
config = AutoConfig.from_json_file('./tf_model/t5_tf_model_config.json')
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained('./tf_model/t5_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||||
|
|
||||||
|
"""
|
||||||
|
config = kwargs.pop("config", None)
|
||||||
|
if not isinstance(config, PretrainedConfig):
|
||||||
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
|
for config_class, model_class in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
|
||||||
|
if isinstance(config, config_class):
|
||||||
|
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||||
|
raise ValueError(
|
||||||
|
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||||
|
"Model type should be one of {}.".format(
|
||||||
|
config.__class__,
|
||||||
|
cls.__name__,
|
||||||
|
", ".join(c.__name__ for c in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys()),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AutoModelForSequenceClassification:
|
class AutoModelForSequenceClassification:
|
||||||
r"""
|
r"""
|
||||||
:class:`~transformers.AutoModelForSequenceClassification` is a generic model class
|
:class:`~transformers.AutoModelForSequenceClassification` is a generic model class
|
||||||
|
|||||||
@@ -987,6 +987,9 @@ class BertLMHeadModel(BertPreTrainedModel):
|
|||||||
class BertForMaskedLM(BertPreTrainedModel):
|
class BertForMaskedLM(BertPreTrainedModel):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
assert (
|
||||||
|
not config.is_decoder
|
||||||
|
), "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."
|
||||||
|
|
||||||
self.bert = BertModel(config)
|
self.bert = BertModel(config)
|
||||||
self.cls = BertOnlyMLMHead(config)
|
self.cls = BertOnlyMLMHead(config)
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
instantiated as a transformer architecture with one of the base model
|
instantiated as a transformer architecture with one of the base model
|
||||||
classes of the library as encoder and another one as
|
classes of the library as encoder and another one as
|
||||||
decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
|
decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
|
||||||
class method for the encoder and `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` class method for the decoder.
|
class method for the encoder and `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)` class method for the decoder.
|
||||||
"""
|
"""
|
||||||
config_class = EncoderDecoderConfig
|
config_class = EncoderDecoderConfig
|
||||||
base_model_prefix = "encoder_decoder"
|
base_model_prefix = "encoder_decoder"
|
||||||
@@ -61,9 +61,9 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
encoder = AutoModel.from_config(config.encoder)
|
encoder = AutoModel.from_config(config.encoder)
|
||||||
|
|
||||||
if decoder is None:
|
if decoder is None:
|
||||||
from transformers import AutoModelWithLMHead
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
decoder = AutoModelWithLMHead.from_config(config.decoder)
|
decoder = AutoModelForCausalLM.from_config(config.decoder)
|
||||||
|
|
||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
self.decoder = decoder
|
self.decoder = decoder
|
||||||
@@ -157,7 +157,7 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
assert (
|
assert (
|
||||||
decoder_pretrained_model_name_or_path is not None
|
decoder_pretrained_model_name_or_path is not None
|
||||||
), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
|
), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
|
||||||
from .modeling_auto import AutoModelWithLMHead
|
from .modeling_auto import AutoModelForCausalLM
|
||||||
|
|
||||||
if "config" not in kwargs_decoder:
|
if "config" not in kwargs_decoder:
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
@@ -176,7 +176,7 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attribute `is_decoder` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` is set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attribute `is_decoder` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` is set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||||
)
|
)
|
||||||
|
|
||||||
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||||
|
|
||||||
return cls(encoder=encoder, decoder=decoder)
|
return cls(encoder=encoder, decoder=decoder)
|
||||||
|
|
||||||
|
|||||||
@@ -26,13 +26,20 @@ if is_torch_available():
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
BertConfig,
|
BertConfig,
|
||||||
|
GPT2Config,
|
||||||
|
T5Config,
|
||||||
AutoModel,
|
AutoModel,
|
||||||
BertModel,
|
BertModel,
|
||||||
AutoModelForPreTraining,
|
AutoModelForPreTraining,
|
||||||
BertForPreTraining,
|
BertForPreTraining,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
GPT2LMHeadModel,
|
||||||
AutoModelWithLMHead,
|
AutoModelWithLMHead,
|
||||||
|
AutoModelForMaskedLM,
|
||||||
BertForMaskedLM,
|
BertForMaskedLM,
|
||||||
RobertaForMaskedLM,
|
RobertaForMaskedLM,
|
||||||
|
AutoModelForSeq2SeqLM,
|
||||||
|
T5ForConditionalGeneration,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
BertForSequenceClassification,
|
BertForSequenceClassification,
|
||||||
AutoModelForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
@@ -41,6 +48,8 @@ if is_torch_available():
|
|||||||
BertForTokenClassification,
|
BertForTokenClassification,
|
||||||
)
|
)
|
||||||
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
from transformers.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
from transformers.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
from transformers.modeling_auto import (
|
from transformers.modeling_auto import (
|
||||||
MODEL_MAPPING,
|
MODEL_MAPPING,
|
||||||
MODEL_FOR_PRETRAINING_MAPPING,
|
MODEL_FOR_PRETRAINING_MAPPING,
|
||||||
@@ -48,6 +57,9 @@ if is_torch_available():
|
|||||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
MODEL_WITH_LM_HEAD_MAPPING,
|
MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
|
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
|
MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -97,6 +109,45 @@ class AutoModelTest(unittest.TestCase):
|
|||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
self.assertIsInstance(model, BertForMaskedLM)
|
self.assertIsInstance(model, BertForMaskedLM)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_model_for_causal_lm(self):
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
config = AutoConfig.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(config)
|
||||||
|
self.assertIsInstance(config, GPT2Config)
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||||
|
model, loading_info = AutoModelForCausalLM.from_pretrained(model_name, output_loading_info=True)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
self.assertIsInstance(model, GPT2LMHeadModel)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_model_for_masked_lm(self):
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
config = AutoConfig.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(config)
|
||||||
|
self.assertIsInstance(config, BertConfig)
|
||||||
|
|
||||||
|
model = AutoModelForMaskedLM.from_pretrained(model_name)
|
||||||
|
model, loading_info = AutoModelForMaskedLM.from_pretrained(model_name, output_loading_info=True)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
self.assertIsInstance(model, BertForMaskedLM)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_model_for_encoder_decoder_lm(self):
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
for model_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
config = AutoConfig.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(config)
|
||||||
|
self.assertIsInstance(config, T5Config)
|
||||||
|
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
||||||
|
model, loading_info = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_loading_info=True)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
self.assertIsInstance(model, T5ForConditionalGeneration)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_sequence_classification_model_from_pretrained(self):
|
def test_sequence_classification_model_from_pretrained(self):
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@@ -163,6 +214,9 @@ class AutoModelTest(unittest.TestCase):
|
|||||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
MODEL_WITH_LM_HEAD_MAPPING,
|
MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
|
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
|
MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
)
|
)
|
||||||
|
|
||||||
for mapping in mappings:
|
for mapping in mappings:
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ if is_torch_available():
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
BertConfig,
|
BertConfig,
|
||||||
BertModel,
|
BertModel,
|
||||||
|
BertLMHeadModel,
|
||||||
BertForMaskedLM,
|
BertForMaskedLM,
|
||||||
BertForNextSentencePrediction,
|
BertForNextSentencePrediction,
|
||||||
BertForPreTraining,
|
BertForPreTraining,
|
||||||
@@ -35,7 +36,7 @@ if is_torch_available():
|
|||||||
BertForTokenClassification,
|
BertForTokenClassification,
|
||||||
BertForMultipleChoice,
|
BertForMultipleChoice,
|
||||||
)
|
)
|
||||||
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST, BertLMHeadModel
|
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
|
||||||
|
|
||||||
class BertModelTester:
|
class BertModelTester:
|
||||||
|
|||||||
Reference in New Issue
Block a user