Refactor AutoModel classes and add Flax Auto classes (#11027)
* Refactor AutoModel classes and add Flax Auto classes * Add new objects to the init * Fix hubconf and sort models * Fix TF tests * Missing coma * Update src/transformers/models/auto/auto_factory.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Fix init * Fix dummies * Other init to fix Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -189,3 +189,52 @@ FlaxAutoModel
|
|||||||
|
|
||||||
.. autoclass:: transformers.FlaxAutoModel
|
.. autoclass:: transformers.FlaxAutoModel
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
FlaxAutoModelForPreTraining
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxAutoModelForPreTraining
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
FlaxAutoModelForMaskedLM
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxAutoModelForMaskedLM
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
FlaxAutoModelForSequenceClassification
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxAutoModelForSequenceClassification
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
FlaxAutoModelForQuestionAnswering
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxAutoModelForQuestionAnswering
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
FlaxAutoModelForTokenClassification
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxAutoModelForTokenClassification
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
FlaxAutoModelForMultipleChoice
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxAutoModelForMultipleChoice
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
FlaxAutoModelForNextSentencePrediction
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxAutoModelForNextSentencePrediction
|
||||||
|
:members:
|
||||||
|
|||||||
36
hubconf.py
36
hubconf.py
@@ -22,9 +22,10 @@ sys.path.append(SRC_DIR)
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModel,
|
AutoModel,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoModelForMaskedLM,
|
||||||
AutoModelForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
AutoModelWithLMHead,
|
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
)
|
)
|
||||||
@@ -86,22 +87,41 @@ def model(*args, **kwargs):
|
|||||||
return AutoModel.from_pretrained(*args, **kwargs)
|
return AutoModel.from_pretrained(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(AutoModelWithLMHead.__doc__)
|
@add_start_docstrings(AutoModelForCausalLM.__doc__)
|
||||||
def modelWithLMHead(*args, **kwargs):
|
def modelForCausalLM(*args, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
# Using torch.hub !
|
# Using torch.hub !
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', 'bert-base-uncased') # Download model and configuration from huggingface.co and cache.
|
model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', 'gpt2') # Download model and configuration from huggingface.co and cache.
|
||||||
model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', './test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', './test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', 'bert-base-uncased', output_attentions=True) # Update configuration during loading
|
model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', 'gpt2', output_attentions=True) # Update configuration during loading
|
||||||
|
assert model.config.output_attentions == True
|
||||||
|
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||||
|
config = AutoConfig.from_pretrained('./tf_model/gpt_tf_model_config.json')
|
||||||
|
model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', './tf_model/gpt_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||||
|
|
||||||
|
"""
|
||||||
|
return AutoModelForCausalLM.from_pretrained(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(AutoModelForMaskedLM.__doc__)
|
||||||
|
def modelForMaskedLM(*args, **kwargs):
|
||||||
|
r"""
|
||||||
|
# Using torch.hub !
|
||||||
|
import torch
|
||||||
|
|
||||||
|
model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', 'bert-base-uncased') # Download model and configuration from huggingface.co and cache.
|
||||||
|
model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', './test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
|
model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', 'bert-base-uncased', output_attentions=True) # Update configuration during loading
|
||||||
assert model.config.output_attentions == True
|
assert model.config.output_attentions == True
|
||||||
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||||
config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json')
|
config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json')
|
||||||
model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return AutoModelWithLMHead.from_pretrained(*args, **kwargs)
|
|
||||||
|
return AutoModelForMaskedLM.from_pretrained(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(AutoModelForSequenceClassification.__doc__)
|
@add_start_docstrings(AutoModelForSequenceClassification.__doc__)
|
||||||
|
|||||||
@@ -1300,7 +1300,26 @@ else:
|
|||||||
# FLAX-backed objects
|
# FLAX-backed objects
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
_import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
|
_import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
|
||||||
_import_structure["models.auto"].extend(["FLAX_MODEL_MAPPING", "FlaxAutoModel"])
|
_import_structure["models.auto"].extend(
|
||||||
|
[
|
||||||
|
"FLAX_MODEL_FOR_MASKED_LM_MAPPING",
|
||||||
|
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||||
|
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||||
|
"FLAX_MODEL_FOR_PRETRAINING_MAPPING",
|
||||||
|
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
|
||||||
|
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
||||||
|
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||||
|
"FLAX_MODEL_MAPPING",
|
||||||
|
"FlaxAutoModel",
|
||||||
|
"FlaxAutoModelForMaskedLM",
|
||||||
|
"FlaxAutoModelForMultipleChoice",
|
||||||
|
"FlaxAutoModelForNextSentencePrediction",
|
||||||
|
"FlaxAutoModelForPreTraining",
|
||||||
|
"FlaxAutoModelForQuestionAnswering",
|
||||||
|
"FlaxAutoModelForSequenceClassification",
|
||||||
|
"FlaxAutoModelForTokenClassification",
|
||||||
|
]
|
||||||
|
)
|
||||||
_import_structure["models.bert"].extend(
|
_import_structure["models.bert"].extend(
|
||||||
[
|
[
|
||||||
"FlaxBertForMaskedLM",
|
"FlaxBertForMaskedLM",
|
||||||
@@ -2410,7 +2429,24 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
from .modeling_flax_utils import FlaxPreTrainedModel
|
from .modeling_flax_utils import FlaxPreTrainedModel
|
||||||
from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel
|
from .models.auto import (
|
||||||
|
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
|
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||||
|
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||||
|
FLAX_MODEL_FOR_PRETRAINING_MAPPING,
|
||||||
|
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
|
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
|
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
|
FLAX_MODEL_MAPPING,
|
||||||
|
FlaxAutoModel,
|
||||||
|
FlaxAutoModelForMaskedLM,
|
||||||
|
FlaxAutoModelForMultipleChoice,
|
||||||
|
FlaxAutoModelForNextSentencePrediction,
|
||||||
|
FlaxAutoModelForPreTraining,
|
||||||
|
FlaxAutoModelForQuestionAnswering,
|
||||||
|
FlaxAutoModelForSequenceClassification,
|
||||||
|
FlaxAutoModelForTokenClassification,
|
||||||
|
)
|
||||||
from .models.bert import (
|
from .models.bert import (
|
||||||
FlaxBertForMaskedLM,
|
FlaxBertForMaskedLM,
|
||||||
FlaxBertForMultipleChoice,
|
FlaxBertForMultipleChoice,
|
||||||
|
|||||||
@@ -82,7 +82,24 @@ if is_tf_available():
|
|||||||
]
|
]
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
_import_structure["modeling_flax_auto"] = ["FLAX_MODEL_MAPPING", "FlaxAutoModel"]
|
_import_structure["modeling_flax_auto"] = [
|
||||||
|
"FLAX_MODEL_FOR_MASKED_LM_MAPPING",
|
||||||
|
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||||
|
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||||
|
"FLAX_MODEL_FOR_PRETRAINING_MAPPING",
|
||||||
|
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
|
||||||
|
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
||||||
|
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||||
|
"FLAX_MODEL_MAPPING",
|
||||||
|
"FlaxAutoModel",
|
||||||
|
"FlaxAutoModelForMaskedLM",
|
||||||
|
"FlaxAutoModelForMultipleChoice",
|
||||||
|
"FlaxAutoModelForNextSentencePrediction",
|
||||||
|
"FlaxAutoModelForPreTraining",
|
||||||
|
"FlaxAutoModelForQuestionAnswering",
|
||||||
|
"FlaxAutoModelForSequenceClassification",
|
||||||
|
"FlaxAutoModelForTokenClassification",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -145,7 +162,24 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
from .modeling_flax_auto import FLAX_MODEL_MAPPING, FlaxAutoModel
|
from .modeling_flax_auto import (
|
||||||
|
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
|
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||||
|
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||||
|
FLAX_MODEL_FOR_PRETRAINING_MAPPING,
|
||||||
|
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
|
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
|
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
|
FLAX_MODEL_MAPPING,
|
||||||
|
FlaxAutoModel,
|
||||||
|
FlaxAutoModelForMaskedLM,
|
||||||
|
FlaxAutoModelForMultipleChoice,
|
||||||
|
FlaxAutoModelForNextSentencePrediction,
|
||||||
|
FlaxAutoModelForPreTraining,
|
||||||
|
FlaxAutoModelForQuestionAnswering,
|
||||||
|
FlaxAutoModelForSequenceClassification,
|
||||||
|
FlaxAutoModelForTokenClassification,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import importlib
|
import importlib
|
||||||
|
|||||||
420
src/transformers/models/auto/auto_factory.py
Normal file
420
src/transformers/models/auto/auto_factory.py
Normal file
@@ -0,0 +1,420 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Factory function to build auto-model classes."""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import types
|
||||||
|
|
||||||
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
|
||||||
|
|
||||||
|
|
||||||
|
CLASS_DOCSTRING = """
|
||||||
|
This is a generic model class that will be instantiated as one of the model classes of the library when created
|
||||||
|
with the :meth:`~transformers.BaseAutoModelClass.from_pretrained` class method or the
|
||||||
|
:meth:`~transformers.BaseAutoModelClass.from_config` class method.
|
||||||
|
|
||||||
|
This class cannot be instantiated directly using ``__init__()`` (throws an error).
|
||||||
|
"""
|
||||||
|
|
||||||
|
FROM_CONFIG_DOCSTRING = """
|
||||||
|
Instantiates one of the model classes of the library from a configuration.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Loading a model from its configuration file does **not** load the model weights. It only affects the
|
||||||
|
model's configuration. Use :meth:`~transformers.BaseAutoModelClass.from_pretrained` to load the model
|
||||||
|
weights.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (:class:`~transformers.PretrainedConfig`):
|
||||||
|
The model class to instantiate is selected based on the configuration class:
|
||||||
|
|
||||||
|
List options
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> from transformers import AutoConfig, BaseAutoModelClass
|
||||||
|
>>> # Download configuration from huggingface.co and cache.
|
||||||
|
>>> config = AutoConfig.from_pretrained('checkpoint_placeholder')
|
||||||
|
>>> model = BaseAutoModelClass.from_config(config)
|
||||||
|
"""
|
||||||
|
|
||||||
|
FROM_PRETRAINED_TORCH_DOCSTRING = """
|
||||||
|
Instantiate one of the model classes of the library from a pretrained model.
|
||||||
|
|
||||||
|
The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either
|
||||||
|
passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing,
|
||||||
|
by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:
|
||||||
|
|
||||||
|
List options
|
||||||
|
|
||||||
|
The model is set in evaluation mode by default using ``model.eval()`` (so for instance, dropout modules are
|
||||||
|
deactivated). To train the model, you should first set it back in training mode with ``model.train()``
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||||
|
Can be either:
|
||||||
|
|
||||||
|
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||||
|
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||||
|
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||||
|
- A path to a `directory` containing model weights saved using
|
||||||
|
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||||
|
- A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In
|
||||||
|
this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
|
||||||
|
as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in
|
||||||
|
a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||||
|
model_args (additional positional arguments, `optional`):
|
||||||
|
Will be passed along to the underlying model ``__init__()`` method.
|
||||||
|
config (:class:`~transformers.PretrainedConfig`, `optional`):
|
||||||
|
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
||||||
|
be automatically loaded when:
|
||||||
|
|
||||||
|
- The model is a model provided by the library (loaded with the `model id` string of a pretrained
|
||||||
|
model).
|
||||||
|
- The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
|
||||||
|
by supplying the save directory.
|
||||||
|
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
|
||||||
|
configuration JSON file named `config.json` is found in the directory.
|
||||||
|
state_dict (`Dict[str, torch.Tensor]`, `optional`):
|
||||||
|
A state dictionary to use instead of a state dictionary loaded from saved weights file.
|
||||||
|
|
||||||
|
This option can be used if you want to create a model from a pretrained configuration but load your own
|
||||||
|
weights. In this case though, you should check if using
|
||||||
|
:func:`~transformers.PreTrainedModel.save_pretrained` and
|
||||||
|
:func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
|
||||||
|
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
|
||||||
|
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||||
|
standard cache should not be used.
|
||||||
|
from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Load the model weights from a TensorFlow checkpoint save file (see docstring of
|
||||||
|
``pretrained_model_name_or_path`` argument).
|
||||||
|
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||||
|
cached versions if they exist.
|
||||||
|
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
||||||
|
file exists.
|
||||||
|
proxies (:obj:`Dict[str, str], `optional`):
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
|
||||||
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||||
|
output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||||
|
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to only look at local files (e.g., not try downloading the model).
|
||||||
|
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
||||||
|
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||||
|
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||||
|
identifier allowed by git.
|
||||||
|
kwargs (additional keyword arguments, `optional`):
|
||||||
|
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||||
|
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
|
||||||
|
automatically loaded:
|
||||||
|
|
||||||
|
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
|
||||||
|
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
|
||||||
|
already been done)
|
||||||
|
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
|
||||||
|
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
|
||||||
|
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
|
||||||
|
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
|
||||||
|
attribute will be passed to the underlying model's ``__init__`` function.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> from transformers import AutoConfig, BaseAutoModelClass
|
||||||
|
|
||||||
|
>>> # Download model and configuration from huggingface.co and cache.
|
||||||
|
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder')
|
||||||
|
|
||||||
|
>>> # Update configuration during loading
|
||||||
|
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True)
|
||||||
|
>>> model.config.output_attentions
|
||||||
|
True
|
||||||
|
|
||||||
|
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||||
|
>>> config = AutoConfig.from_pretrained('./tf_model/shortcut_placeholder_tf_model_config.json')
|
||||||
|
>>> model = BaseAutoModelClass.from_pretrained('./tf_model/shortcut_placeholder_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||||
|
"""
|
||||||
|
|
||||||
|
FROM_PRETRAINED_TF_DOCSTRING = """
|
||||||
|
Instantiate one of the model classes of the library from a pretrained model.
|
||||||
|
|
||||||
|
The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either
|
||||||
|
passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing,
|
||||||
|
by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:
|
||||||
|
|
||||||
|
List options
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||||
|
Can be either:
|
||||||
|
|
||||||
|
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||||
|
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||||
|
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||||
|
- A path to a `directory` containing model weights saved using
|
||||||
|
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||||
|
- A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In
|
||||||
|
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
|
||||||
|
as ``config`` argument. This loading path is slower than converting the PyTorch model in a
|
||||||
|
TensorFlow model using the provided conversion scripts and loading the TensorFlow model
|
||||||
|
afterwards.
|
||||||
|
model_args (additional positional arguments, `optional`):
|
||||||
|
Will be passed along to the underlying model ``__init__()`` method.
|
||||||
|
config (:class:`~transformers.PretrainedConfig`, `optional`):
|
||||||
|
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
||||||
|
be automatically loaded when:
|
||||||
|
|
||||||
|
- The model is a model provided by the library (loaded with the `model id` string of a pretrained
|
||||||
|
model).
|
||||||
|
- The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
|
||||||
|
by supplying the save directory.
|
||||||
|
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
|
||||||
|
configuration JSON file named `config.json` is found in the directory.
|
||||||
|
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
|
||||||
|
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||||
|
standard cache should not be used.
|
||||||
|
from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Load the model weights from a PyTorch checkpoint save file (see docstring of
|
||||||
|
``pretrained_model_name_or_path`` argument).
|
||||||
|
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||||
|
cached versions if they exist.
|
||||||
|
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
||||||
|
file exists.
|
||||||
|
proxies (:obj:`Dict[str, str], `optional`):
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
|
||||||
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||||
|
output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||||
|
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to only look at local files (e.g., not try downloading the model).
|
||||||
|
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
||||||
|
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||||
|
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||||
|
identifier allowed by git.
|
||||||
|
kwargs (additional keyword arguments, `optional`):
|
||||||
|
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||||
|
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
|
||||||
|
automatically loaded:
|
||||||
|
|
||||||
|
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
|
||||||
|
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
|
||||||
|
already been done)
|
||||||
|
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
|
||||||
|
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
|
||||||
|
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
|
||||||
|
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
|
||||||
|
attribute will be passed to the underlying model's ``__init__`` function.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> from transformers import AutoConfig, BaseAutoModelClass
|
||||||
|
|
||||||
|
>>> # Download model and configuration from huggingface.co and cache.
|
||||||
|
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder')
|
||||||
|
|
||||||
|
>>> # Update configuration during loading
|
||||||
|
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True)
|
||||||
|
>>> model.config.output_attentions
|
||||||
|
True
|
||||||
|
|
||||||
|
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
|
||||||
|
>>> config = AutoConfig.from_pretrained('./pt_model/shortcut_placeholder_pt_model_config.json')
|
||||||
|
>>> model = BaseAutoModelClass.from_pretrained('./pt_model/shortcut_placeholder_pytorch_model.bin', from_pt=True, config=config)
|
||||||
|
"""
|
||||||
|
|
||||||
|
FROM_PRETRAINED_FLAX_DOCSTRING = """
|
||||||
|
Instantiate one of the model classes of the library from a pretrained model.
|
||||||
|
|
||||||
|
The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either
|
||||||
|
passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing,
|
||||||
|
by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:
|
||||||
|
|
||||||
|
List options
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||||
|
Can be either:
|
||||||
|
|
||||||
|
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||||
|
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||||
|
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||||
|
- A path to a `directory` containing model weights saved using
|
||||||
|
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||||
|
- A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In
|
||||||
|
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
|
||||||
|
as ``config`` argument. This loading path is slower than converting the PyTorch model in a
|
||||||
|
TensorFlow model using the provided conversion scripts and loading the TensorFlow model
|
||||||
|
afterwards.
|
||||||
|
model_args (additional positional arguments, `optional`):
|
||||||
|
Will be passed along to the underlying model ``__init__()`` method.
|
||||||
|
config (:class:`~transformers.PretrainedConfig`, `optional`):
|
||||||
|
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
||||||
|
be automatically loaded when:
|
||||||
|
|
||||||
|
- The model is a model provided by the library (loaded with the `model id` string of a pretrained
|
||||||
|
model).
|
||||||
|
- The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
|
||||||
|
by supplying the save directory.
|
||||||
|
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
|
||||||
|
configuration JSON file named `config.json` is found in the directory.
|
||||||
|
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
|
||||||
|
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||||
|
standard cache should not be used.
|
||||||
|
from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Load the model weights from a PyTorch checkpoint save file (see docstring of
|
||||||
|
``pretrained_model_name_or_path`` argument).
|
||||||
|
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||||
|
cached versions if they exist.
|
||||||
|
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
||||||
|
file exists.
|
||||||
|
proxies (:obj:`Dict[str, str], `optional`):
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
|
||||||
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||||
|
output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||||
|
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to only look at local files (e.g., not try downloading the model).
|
||||||
|
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
||||||
|
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||||
|
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||||
|
identifier allowed by git.
|
||||||
|
kwargs (additional keyword arguments, `optional`):
|
||||||
|
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||||
|
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
|
||||||
|
automatically loaded:
|
||||||
|
|
||||||
|
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
|
||||||
|
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
|
||||||
|
already been done)
|
||||||
|
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
|
||||||
|
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
|
||||||
|
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
|
||||||
|
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
|
||||||
|
attribute will be passed to the underlying model's ``__init__`` function.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> from transformers import AutoConfig, BaseAutoModelClass
|
||||||
|
|
||||||
|
>>> # Download model and configuration from huggingface.co and cache.
|
||||||
|
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder')
|
||||||
|
|
||||||
|
>>> # Update configuration during loading
|
||||||
|
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True)
|
||||||
|
>>> model.config.output_attentions
|
||||||
|
True
|
||||||
|
|
||||||
|
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
|
||||||
|
>>> config = AutoConfig.from_pretrained('./pt_model/shortcut_placeholder_pt_model_config.json')
|
||||||
|
>>> model = BaseAutoModelClass.from_pretrained('./pt_model/shortcut_placeholder_pytorch_model.bin', from_pt=True, config=config)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class _BaseAutoModelClass:
|
||||||
|
# Base class for auto models.
|
||||||
|
_model_mapping = None
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{self.__class__.__name__} is designed to be instantiated "
|
||||||
|
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
|
||||||
|
f"`{self.__class__.__name__}.from_config(config)` methods."
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_config(cls, config, **kwargs):
|
||||||
|
if type(config) in cls._model_mapping.keys():
|
||||||
|
return cls._model_mapping[type(config)](config, **kwargs)
|
||||||
|
raise ValueError(
|
||||||
|
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
|
||||||
|
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
|
config = kwargs.pop("config", None)
|
||||||
|
kwargs["_from_auto"] = True
|
||||||
|
if not isinstance(config, PretrainedConfig):
|
||||||
|
config, kwargs = AutoConfig.from_pretrained(
|
||||||
|
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if type(config) in cls._model_mapping.keys():
|
||||||
|
return cls._model_mapping[type(config)].from_pretrained(
|
||||||
|
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
|
||||||
|
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def copy_func(f):
|
||||||
|
""" Returns a copy of a function f."""
|
||||||
|
# Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
|
||||||
|
g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
|
||||||
|
g = functools.update_wrapper(g, f)
|
||||||
|
g.__kwdefaults__ = f.__kwdefaults__
|
||||||
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
def insert_head_doc(docstring, head_doc=""):
|
||||||
|
if len(head_doc) > 0:
|
||||||
|
return docstring.replace(
|
||||||
|
"one of the model classes of the library ",
|
||||||
|
f"one of the model classes of the library (with a {head_doc} head) ",
|
||||||
|
)
|
||||||
|
return docstring.replace(
|
||||||
|
"one of the model classes of the library ", "one of the base model classes of the library "
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-cased", head_doc=""):
|
||||||
|
# Create a new class with the right name from the base class
|
||||||
|
new_class = types.new_class(name, (_BaseAutoModelClass,))
|
||||||
|
new_class._model_mapping = model_mapping
|
||||||
|
class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc)
|
||||||
|
new_class.__doc__ = class_docstring.replace("BaseAutoModelClass", name)
|
||||||
|
|
||||||
|
# Now we need to copy and re-register `from_config` and `from_pretrained` as class methods otherwise we can't
|
||||||
|
# have a specific docstrings for them.
|
||||||
|
from_config = copy_func(_BaseAutoModelClass.from_config)
|
||||||
|
from_config_docstring = insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc)
|
||||||
|
from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
|
||||||
|
from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
|
||||||
|
from_config.__doc__ = from_config_docstring
|
||||||
|
from_config = replace_list_option_in_docstrings(model_mapping, use_model_types=False)(from_config)
|
||||||
|
new_class.from_config = classmethod(from_config)
|
||||||
|
|
||||||
|
if name.startswith("TF"):
|
||||||
|
from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING
|
||||||
|
elif name.startswith("Flax"):
|
||||||
|
from_pretrained_docstring = FROM_PRETRAINED_FLAX_DOCSTRING
|
||||||
|
else:
|
||||||
|
from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING
|
||||||
|
from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained)
|
||||||
|
from_pretrained_docstring = insert_head_doc(from_pretrained_docstring, head_doc=head_doc)
|
||||||
|
from_pretrained_docstring = from_pretrained_docstring.replace("BaseAutoModelClass", name)
|
||||||
|
from_pretrained_docstring = from_pretrained_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
|
||||||
|
shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
|
||||||
|
from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
|
||||||
|
from_pretrained.__doc__ = from_pretrained_docstring
|
||||||
|
from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
|
||||||
|
new_class.from_pretrained = classmethod(from_pretrained)
|
||||||
|
return new_class
|
||||||
@@ -256,8 +256,8 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True):
|
|||||||
if config in config_to_class
|
if config in config_to_class
|
||||||
}
|
}
|
||||||
lines = [
|
lines = [
|
||||||
f"{indent}- **{model_type}** -- :class:`~transformers.{cls_name}` ({MODEL_NAMES_MAPPING[model_type]} model)"
|
f"{indent}- **{model_type}** -- :class:`~transformers.{model_type_to_name[model_type]}` ({MODEL_NAMES_MAPPING[model_type]} model)"
|
||||||
for model_type, cls_name in model_type_to_name.items()
|
for model_type in sorted(model_type_to_name.keys())
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
config_to_name = {config.__name__: clas.__name__ for config, clas in config_to_class.items()}
|
config_to_name = {config.__name__: clas.__name__ for config, clas in config_to_class.items()}
|
||||||
@@ -265,8 +265,8 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True):
|
|||||||
config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
|
config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
|
||||||
}
|
}
|
||||||
lines = [
|
lines = [
|
||||||
f"{indent}- :class:`~transformers.{config_name}` configuration class: :class:`~transformers.{cls_name}` ({config_to_model_name[config_name]} model)"
|
f"{indent}- :class:`~transformers.{config_name}` configuration class: :class:`~transformers.{config_to_name[config_name]}` ({config_to_model_name[config_name]} model)"
|
||||||
for config_name, cls_name in config_to_name.items()
|
for config_name in sorted(config_to_name.keys())
|
||||||
]
|
]
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -17,11 +17,20 @@
|
|||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..bert.modeling_flax_bert import FlaxBertModel
|
from ..bert.modeling_flax_bert import (
|
||||||
|
FlaxBertForMaskedLM,
|
||||||
|
FlaxBertForMultipleChoice,
|
||||||
|
FlaxBertForNextSentencePrediction,
|
||||||
|
FlaxBertForPreTraining,
|
||||||
|
FlaxBertForQuestionAnswering,
|
||||||
|
FlaxBertForSequenceClassification,
|
||||||
|
FlaxBertForTokenClassification,
|
||||||
|
FlaxBertModel,
|
||||||
|
)
|
||||||
from ..roberta.modeling_flax_roberta import FlaxRobertaModel
|
from ..roberta.modeling_flax_roberta import FlaxRobertaModel
|
||||||
from .configuration_auto import AutoConfig, BertConfig, RobertaConfig
|
from .auto_factory import auto_class_factory
|
||||||
|
from .configuration_auto import BertConfig, RobertaConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -29,140 +38,90 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
FLAX_MODEL_MAPPING = OrderedDict(
|
FLAX_MODEL_MAPPING = OrderedDict(
|
||||||
[
|
[
|
||||||
|
# Base model mapping
|
||||||
(RobertaConfig, FlaxRobertaModel),
|
(RobertaConfig, FlaxRobertaModel),
|
||||||
(BertConfig, FlaxBertModel),
|
(BertConfig, FlaxBertModel),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
||||||
|
[
|
||||||
|
# Model for pre-training mapping
|
||||||
|
(BertConfig, FlaxBertForPreTraining),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
class FlaxAutoModel(object):
|
FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
||||||
r"""
|
[
|
||||||
:class:`~transformers.FlaxAutoModel` is a generic model class that will be instantiated as one of the base model
|
# Model for Masked LM mapping
|
||||||
classes of the library when created with the `FlaxAutoModel.from_pretrained(pretrained_model_name_or_path)` or the
|
(BertConfig, FlaxBertForMaskedLM),
|
||||||
`FlaxAutoModel.from_config(config)` class methods.
|
]
|
||||||
|
)
|
||||||
|
|
||||||
This class cannot be instantiated using `__init__()` (throws an error).
|
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||||
"""
|
[
|
||||||
|
# Model for Sequence Classification mapping
|
||||||
|
(BertConfig, FlaxBertForSequenceClassification),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self):
|
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
||||||
raise EnvironmentError(
|
[
|
||||||
"FlaxAutoModel is designed to be instantiated "
|
# Model for Question Answering mapping
|
||||||
"using the `FlaxAutoModel.from_pretrained(pretrained_model_name_or_path)` or "
|
(BertConfig, FlaxBertForQuestionAnswering),
|
||||||
"`FlaxAutoModel.from_config(config)` methods."
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
|
||||||
def from_config(cls, config):
|
[
|
||||||
r"""
|
# Model for Token Classification mapping
|
||||||
Instantiates one of the base model classes of the library from a configuration.
|
(BertConfig, FlaxBertForTokenClassification),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
Args:
|
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
|
||||||
config (:class:`~transformers.PretrainedConfig`):
|
[
|
||||||
The model class to instantiate is selected based on the configuration class:
|
# Model for Multiple Choice mapping
|
||||||
|
(BertConfig, FlaxBertForMultipleChoice),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
- isInstance of `roberta` configuration class: :class:`~transformers.FlaxRobertaModel` (RoBERTa model)
|
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
|
||||||
- isInstance of `bert` configuration class: :class:`~transformers.FlaxBertModel` (Bert model
|
[
|
||||||
|
(BertConfig, FlaxBertForNextSentencePrediction),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
Examples::
|
FlaxAutoModel = auto_class_factory("FlaxAutoModel", FLAX_MODEL_MAPPING)
|
||||||
|
|
||||||
config = BertConfig.from_pretrained('bert-base-uncased')
|
FlaxAutoModelForPreTraining = auto_class_factory(
|
||||||
# Download configuration from huggingface.co and cache.
|
"FlaxAutoModelForPreTraining", FLAX_MODEL_FOR_PRETRAINING_MAPPING, head_doc="pretraining"
|
||||||
model = FlaxAutoModel.from_config(config)
|
)
|
||||||
# E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
|
||||||
"""
|
|
||||||
for config_class, model_class in FLAX_MODEL_MAPPING.items():
|
|
||||||
if isinstance(config, config_class):
|
|
||||||
return model_class(config)
|
|
||||||
raise ValueError(
|
|
||||||
f"Unrecognized configuration class {config.__class__} "
|
|
||||||
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
|
|
||||||
f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
FlaxAutoModelForMaskedLM = auto_class_factory(
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
"FlaxAutoModelForMaskedLM", FLAX_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
|
||||||
r"""
|
)
|
||||||
Instantiates one of the base model classes of the library from a pre-trained model configuration.
|
|
||||||
|
|
||||||
The `from_pretrained()` method takes care of returning the correct model class instance based on the
|
FlaxAutoModelForSequenceClassification = auto_class_factory(
|
||||||
`model_type` property of the config object, or when it's missing, falling back to using pattern matching on the
|
"AFlaxutoModelForSequenceClassification",
|
||||||
`pretrained_model_name_or_path` string.
|
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
|
head_doc="sequence classification",
|
||||||
|
)
|
||||||
|
|
||||||
The base model class to instantiate is selected as the first pattern matching in the
|
FlaxAutoModelForQuestionAnswering = auto_class_factory(
|
||||||
`pretrained_model_name_or_path` string (in the following order):
|
"FlaxAutoModelForQuestionAnswering", FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, head_doc="question answering"
|
||||||
|
)
|
||||||
|
|
||||||
- contains `roberta`: :class:`~transformers.FlaxRobertaModel` (RoBERTa model)
|
FlaxAutoModelForTokenClassification = auto_class_factory(
|
||||||
- contains `bert`: :class:`~transformers.FlaxBertModel` (Bert model)
|
"FlaxAutoModelForTokenClassification", FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification"
|
||||||
|
)
|
||||||
|
|
||||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) To
|
FlaxAutoModelForMultipleChoice = auto_class_factory(
|
||||||
train the model, you should first set it back in training mode with `model.train()`
|
"AutoModelForMultipleChoice", FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, head_doc="multiple choice"
|
||||||
|
)
|
||||||
|
|
||||||
Args:
|
FlaxAutoModelForNextSentencePrediction = auto_class_factory(
|
||||||
pretrained_model_name_or_path: either:
|
"FlaxAutoModelForNextSentencePrediction",
|
||||||
|
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||||
- a string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. Valid
|
head_doc="next sentence prediction",
|
||||||
model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under a user or
|
)
|
||||||
organization name, like ``dbmdz/bert-base-german-cased``.
|
|
||||||
- a path to a `directory` containing model weights saved using
|
|
||||||
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
|
|
||||||
- a path or url to a `pytorch index checkpoint file` (e.g. `./pt_model/pytorch_model.bin`). In this
|
|
||||||
case, ``from_pt`` should be set to True and a configuration object should be provided as ``config``
|
|
||||||
argument.
|
|
||||||
|
|
||||||
model_args: (`optional`) Sequence of positional arguments:
|
|
||||||
All remaining positional arguments will be passed to the underlying model's ``__init__`` method
|
|
||||||
|
|
||||||
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
|
|
||||||
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
|
||||||
be automatically loaded when:
|
|
||||||
|
|
||||||
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a
|
|
||||||
pretrained model), or
|
|
||||||
- the model was saved using :func:`~transformers.FlaxPreTrainedModel.save_pretrained` and is reloaded
|
|
||||||
by supplying the save directory.
|
|
||||||
- the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
|
|
||||||
configuration JSON file named `config.json` is found in the directory.
|
|
||||||
|
|
||||||
cache_dir: (`optional`) string:
|
|
||||||
Path to a directory in which a downloaded pre-trained model configuration should be cached if the
|
|
||||||
standard cache should not be used.
|
|
||||||
|
|
||||||
force_download: (`optional`) boolean, default False:
|
|
||||||
Force to (re-)download the model weights and configuration files and override the cached versions if
|
|
||||||
they exists.
|
|
||||||
|
|
||||||
resume_download: (`optional`) boolean, default False:
|
|
||||||
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
|
|
||||||
|
|
||||||
proxies: (`optional`) dict, default None:
|
|
||||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128',
|
|
||||||
'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
|
|
||||||
|
|
||||||
output_loading_info: (`optional`) boolean:
|
|
||||||
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error
|
|
||||||
messages.
|
|
||||||
|
|
||||||
kwargs: (`optional`) Remaining dictionary of keyword arguments:
|
|
||||||
These arguments will be passed to the configuration and the model.
|
|
||||||
|
|
||||||
Examples::
|
|
||||||
|
|
||||||
model = FlaxAutoModel.from_pretrained('bert-base-uncased') # Download model and configuration from huggingface.co and cache.
|
|
||||||
model = FlaxAutoModel.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
|
||||||
assert model.config.output_attention == True
|
|
||||||
|
|
||||||
"""
|
|
||||||
config = kwargs.pop("config", None)
|
|
||||||
if not isinstance(config, PretrainedConfig):
|
|
||||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
|
||||||
|
|
||||||
for config_class, model_class in FLAX_MODEL_MAPPING.items():
|
|
||||||
if isinstance(config, config_class):
|
|
||||||
return model_class.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, _from_auto=True, **kwargs
|
|
||||||
)
|
|
||||||
raise ValueError(
|
|
||||||
f"Unrecognized configuration class {config.__class__} "
|
|
||||||
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
|
|
||||||
f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}"
|
|
||||||
)
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -11,6 +11,27 @@ class FlaxPreTrainedModel:
|
|||||||
requires_flax(self)
|
requires_flax(self)
|
||||||
|
|
||||||
|
|
||||||
|
FLAX_MODEL_FOR_MASKED_LM_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
|
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
|
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
|
FLAX_MODEL_FOR_PRETRAINING_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
|
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
|
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
|
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
FLAX_MODEL_MAPPING = None
|
FLAX_MODEL_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
@@ -23,6 +44,69 @@ class FlaxAutoModel:
|
|||||||
requires_flax(self)
|
requires_flax(self)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxAutoModelForMaskedLM:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxAutoModelForMultipleChoice:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxAutoModelForNextSentencePrediction:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxAutoModelForPreTraining:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxAutoModelForQuestionAnswering:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxAutoModelForSequenceClassification:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxAutoModelForTokenClassification:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertForMaskedLM:
|
class FlaxBertForMaskedLM:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_flax(self)
|
requires_flax(self)
|
||||||
|
|||||||
Reference in New Issue
Block a user