Update doc of the model page (#5985)
This commit is contained in:
@@ -266,34 +266,43 @@ class ModuleUtilsMixin:
|
||||
|
||||
|
||||
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
r""" Base class for all models.
|
||||
r"""
|
||||
Base class for all models.
|
||||
|
||||
:class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
|
||||
as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
|
||||
:class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods
|
||||
for loading, downloading and saving models as well as a few methods common to all models to:
|
||||
|
||||
Class attributes (overridden by derived classes):
|
||||
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
|
||||
- ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
|
||||
* resize the input embeddings,
|
||||
* prune heads in the self-attention heads.
|
||||
|
||||
- ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
|
||||
- ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
|
||||
- ``path``: a path (string) to the TensorFlow checkpoint.
|
||||
Class attributes (overridden by derived classes):
|
||||
- **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of
|
||||
:class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
|
||||
- **load_tf_weights** (:obj:`Callable`) -- A python `method` for loading a TensorFlow checkpoint in a
|
||||
PyTorch model, taking as arguments:
|
||||
|
||||
- ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
|
||||
- **model** (:class:`~transformers.PreTrainedModel`) -- An instance of the model on which to load the
|
||||
TensorFlow checkpoint.
|
||||
- **config** (:class:`~transformers.PretrainedConfig`) -- An instance of the configuration associated
|
||||
to the model.
|
||||
- **path** (:obj:`str`) -- A path to the TensorFlow checkpoint.
|
||||
|
||||
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
|
||||
derived classes of the same architecture adding modules on top of the base model.
|
||||
"""
|
||||
config_class = None
|
||||
base_model_prefix = ""
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
|
||||
""" Dummy inputs to do a forward pass in the network.
|
||||
|
||||
Returns:
|
||||
torch.Tensor with dummy inputs
|
||||
:obj:`Dict[str, torch.Tensor]`: The dummy inputs.
|
||||
"""
|
||||
return {"input_ids": torch.tensor(DUMMY_INPUTS)}
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
|
||||
super().__init__()
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
raise ValueError(
|
||||
@@ -310,13 +319,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
def base_model(self):
|
||||
return getattr(self, self.base_model_prefix, self)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
"""
|
||||
Returns the model's input embeddings.
|
||||
|
||||
Returns:
|
||||
:obj:`nn.Module`:
|
||||
A torch module mapping vocabulary to hidden states.
|
||||
:obj:`nn.Module`: A torch module mapping vocabulary to hidden states.
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
if base_model is not self:
|
||||
@@ -329,8 +337,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
Set model's input embeddings
|
||||
|
||||
Args:
|
||||
value (:obj:`nn.Module`):
|
||||
A module mapping vocabulary to hidden states.
|
||||
value (:obj:`nn.Module`): A module mapping vocabulary to hidden states.
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
if base_model is not self:
|
||||
@@ -338,20 +345,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_output_embeddings(self):
|
||||
def get_output_embeddings(self) -> nn.Module:
|
||||
"""
|
||||
Returns the model's output embeddings.
|
||||
|
||||
Returns:
|
||||
:obj:`nn.Module`:
|
||||
A torch module mapping hidden states to vocabulary.
|
||||
:obj:`nn.Module`: A torch module mapping hidden states to vocabulary.
|
||||
"""
|
||||
return None # Overwrite for models with output embeddings
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
Tie the weights between the input embeddings and the output embeddings.
|
||||
If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning
|
||||
|
||||
If the :obj:`torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning
|
||||
the weights instead.
|
||||
"""
|
||||
output_embeddings = self.get_output_embeddings()
|
||||
@@ -376,18 +383,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
|
||||
output_embeddings.out_features = input_embeddings.num_embeddings
|
||||
|
||||
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None):
|
||||
""" Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
|
||||
Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
|
||||
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding:
|
||||
"""
|
||||
Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.
|
||||
|
||||
Takes care of tying weights embeddings afterwards if the model class has a :obj:`tie_weights()` method.
|
||||
|
||||
Arguments:
|
||||
new_num_tokens (:obj:`int`, `optional`):
|
||||
The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
|
||||
vectors at the end. Reducing the size will remove vectors from the end. If not provided or :obj:`None`,
|
||||
just returns a pointer to the input tokens :obj:`torch.nn.Embedding` module of the model without doing
|
||||
anything.
|
||||
|
||||
new_num_tokens: (`optional`) int:
|
||||
New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
|
||||
If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
|
||||
|
||||
Return: ``torch.nn.Embeddings``
|
||||
Pointer to the input tokens Embeddings Module of the model
|
||||
Return:
|
||||
:obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
||||
model_embeds = base_model._resize_token_embeddings(new_num_tokens)
|
||||
@@ -412,20 +422,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
def _get_resized_embeddings(
|
||||
self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None
|
||||
) -> torch.nn.Embedding:
|
||||
""" Build a resized Embedding Module from a provided token Embedding Module.
|
||||
Increasing the size will add newly initialized vectors at the end
|
||||
Reducing the size will remove vectors from the end
|
||||
"""
|
||||
Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
|
||||
initialized vectors at the end. Reducing the size will remove vectors from the end.
|
||||
|
||||
Args:
|
||||
old_embeddings: ``torch.nn.Embedding``
|
||||
old_embeddings (:obj:`torch.nn.Embedding`):
|
||||
Old embeddings to be resized.
|
||||
new_num_tokens: (`optional`) int
|
||||
new_num_tokens (:obj:`int`, `optional`):
|
||||
New number of tokens in the embedding matrix.
|
||||
Increasing the size will add newly initialized vectors at the end
|
||||
Reducing the size will remove vectors from the end
|
||||
If not provided or None: return the provided token Embedding Module.
|
||||
Return: ``torch.nn.Embedding``
|
||||
Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
|
||||
|
||||
Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
|
||||
vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens
|
||||
:obj:`torch.nn.Embedding` module of the model without doing anything.
|
||||
|
||||
Return:
|
||||
:obj:`torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if
|
||||
:obj:`new_num_tokens` is :obj:`None`
|
||||
"""
|
||||
if new_num_tokens is None:
|
||||
return old_embeddings
|
||||
@@ -448,7 +461,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
return new_embeddings
|
||||
|
||||
def init_weights(self):
|
||||
""" Initialize and prunes weights if needed. """
|
||||
"""
|
||||
Initializes and prunes weights if needed.
|
||||
"""
|
||||
# Initialize weights
|
||||
self.apply(self._init_weights)
|
||||
|
||||
@@ -459,13 +474,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
# Tie weights if needed
|
||||
self.tie_weights()
|
||||
|
||||
def prune_heads(self, heads_to_prune: Dict):
|
||||
""" Prunes heads of the base model.
|
||||
def prune_heads(self, heads_to_prune: Dict[int, List[int]]):
|
||||
"""
|
||||
Prunes heads of the base model.
|
||||
|
||||
Arguments:
|
||||
|
||||
heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
|
||||
E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
|
||||
Arguments:
|
||||
heads_to_prune (:obj:`Dict[int, List[int]]`):
|
||||
Dictionary with keys being selected layer indices (:obj:`int`) and associated values being the list
|
||||
of heads to prune in said layer (list of :obj:`int`). For instance {1: [0, 2], 2: [2, 3]} will
|
||||
prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
|
||||
"""
|
||||
# save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
|
||||
for layer, heads in heads_to_prune.items():
|
||||
@@ -475,11 +492,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
self.base_model._prune_heads(heads_to_prune)
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
""" Save a model and its configuration file to a directory, so that it
|
||||
can be re-loaded using the `:func:`~transformers.PreTrainedModel.from_pretrained`` class method.
|
||||
"""
|
||||
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||
:func:`~transformers.PreTrainedModel.from_pretrained` class method.
|
||||
|
||||
Arguments:
|
||||
save_directory: directory to which to save.
|
||||
Arguments:
|
||||
save_directory (:obj:`str`):
|
||||
Directory to which to save. Will be created if it doesn't exist.
|
||||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
|
||||
@@ -511,75 +530,110 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
|
||||
r"""
|
||||
Instantiate a pretrained pytorch model from a pre-trained model configuration.
|
||||
|
||||
The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
|
||||
To train the model, you should first set it back in training mode with ``model.train()``
|
||||
The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated).
|
||||
To train the model, you should first set it back in training mode with ``model.train()``.
|
||||
|
||||
The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
|
||||
It is up to you to train those weights with a downstream fine-tuning task.
|
||||
The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
|
||||
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
||||
task.
|
||||
|
||||
The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.
|
||||
The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
|
||||
weights are discarded.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path: either:
|
||||
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
|
||||
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
|
||||
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
|
||||
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||
- None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)
|
||||
pretrained_model_name_or_path (:obj:`str`, `optional`):
|
||||
Can be either:
|
||||
|
||||
model_args: (`optional`) Sequence of positional arguments:
|
||||
All remaning positional arguments will be passed to the underlying model's ``__init__`` method
|
||||
- A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
|
||||
``bert-base-uncased``.
|
||||
- A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
|
||||
``dbmdz/bert-base-german-cased``.
|
||||
- A path to a `directory` containing model weights saved using
|
||||
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a `tensorflow index checkpoint file` (e.g., `./tf_model/model.ckpt.index`). In
|
||||
this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
|
||||
as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in
|
||||
a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||
- :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
|
||||
arguments ``config`` and ``state_dict``).
|
||||
model_args (sequence of positional arguments, `optional`):
|
||||
All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
|
||||
config (:obj:`Union[PretrainedConfig, str]`, `optional`):
|
||||
Can be either:
|
||||
|
||||
config: (`optional`) one of:
|
||||
- an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
|
||||
- a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
|
||||
- an instance of a class derived from :class:`~transformers.PretrainedConfig`,
|
||||
- a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.
|
||||
|
||||
Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:
|
||||
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
|
||||
- the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
|
||||
- the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
|
||||
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
||||
be automatically loaded when:
|
||||
|
||||
state_dict: (`optional`) dict:
|
||||
an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file.
|
||||
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
|
||||
In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
|
||||
- The model is a model provided by the library (loaded with the `shortcut name` string of a
|
||||
pretrained model).
|
||||
- The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
|
||||
by supplying the save directory.
|
||||
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
|
||||
configuration JSON file named `config.json` is found in the directory.
|
||||
state_dict (:obj:`Dict[str, torch.Tensor]`, `optional`):
|
||||
A state dictionary to use instead of a state dictionary loaded from saved weights file.
|
||||
|
||||
cache_dir: (`optional`) string:
|
||||
Path to a directory in which a downloaded pre-trained model
|
||||
configuration should be cached if the standard cache should not be used.
|
||||
This option can be used if you want to create a model from a pretrained configuration but load your own
|
||||
weights. In this case though, you should check if using
|
||||
:func:`~transformers.PreTrainedModel.save_pretrained` and
|
||||
:func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
|
||||
cache_dir (:obj:`str`, `optional`):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Load the model weights from a TensorFlow checkpoint save file (see docstring of
|
||||
``pretrained_model_name_or_path`` argument).
|
||||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
||||
file exists.
|
||||
proxies (:obj:`Dict[str, str]`, `optional`):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.,
|
||||
:obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each
|
||||
request.
|
||||
output_loading_info (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error
|
||||
messages.
|
||||
local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to only look at local files (e.g., not try downloading the model).
|
||||
use_cdn (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
|
||||
our S3 (faster).
|
||||
kwargs (remaining dictionary of keyword arguments, `optional`):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
:obj:`output_attention=True`). Behaves differently depending on whether a ``config`` is provided or
|
||||
automatically loaded:
|
||||
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
|
||||
resume_download: (`optional`) boolean, default False:
|
||||
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
|
||||
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
|
||||
output_loading_info: (`optional`) boolean:
|
||||
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
|
||||
|
||||
kwargs: (`optional`) Remaining dictionary of keyword arguments:
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
|
||||
|
||||
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
|
||||
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
|
||||
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
|
||||
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
|
||||
already been done)
|
||||
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
|
||||
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
|
||||
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
|
||||
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
|
||||
attribute will be passed to the underlying model's ``__init__`` function.
|
||||
|
||||
Examples::
|
||||
|
||||
# For example purposes. Not runnable.
|
||||
model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
|
||||
model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||
model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
|
||||
from transformers import BertConfig, BertModel
|
||||
# Download model and configuration from S3 and cache.
|
||||
model = BertModel.from_pretrained('bert-base-uncased')
|
||||
# Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
|
||||
model = BertModel.from_pretrained('./test/saved_model/')
|
||||
# Update configuration during loading.
|
||||
model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)
|
||||
assert model.config.output_attention == True
|
||||
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||
# Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
|
||||
config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
|
||||
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
state_dict = kwargs.pop("state_dict", None)
|
||||
@@ -1242,18 +1296,23 @@ def apply_chunking_to_forward(
|
||||
chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`.
|
||||
It then applies a layer `forward_fn` to each chunk independently to save memory.
|
||||
If the `forward_fn` is independent across the `chunk_dim` this function will yield the
|
||||
same result as not applying it.
|
||||
This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the
|
||||
dimension :obj:`chunk_dim`. It then applies a layer :obj:`forward_fn` to each chunk independently to save memory.
|
||||
|
||||
If the :obj:`forward_fn` is independent across the :obj:`chunk_dim` this function will yield the same result as
|
||||
directly applying :obj:`forward_fn` to :obj:`input_tensors`.
|
||||
|
||||
Args:
|
||||
chunk_size: int - the chunk size of a chunked tensor. `num_chunks` = `len(input_tensors[0]) / chunk_size`
|
||||
chunk_dim: int - the dimension over which the input_tensors should be chunked
|
||||
forward_fn: fn - the forward fn of the model
|
||||
input_tensors: tuple(torch.Tensor) - the input tensors of `forward_fn` which are chunked
|
||||
chunk_size (:obj:`int`):
|
||||
The chunk size of a chunked tensor: :obj:`num_chunks = len(input_tensors[0]) / chunk_size`.
|
||||
chunk_dim (:obj:`int`):
|
||||
The dimension over which the :obj:`input_tensors` should be chunked.
|
||||
forward_fn (:obj:`Callable[..., torch.Tensor]`):
|
||||
The forward function of the model.
|
||||
input_tensors (:obj:`Tuple[torch.Tensor]`):
|
||||
The input tensors of ``forward_fn`` which will be chunked.
|
||||
Returns:
|
||||
a Tensor with the same shape the foward_fn would have given if applied
|
||||
:obj:`torch.Tensor`: A tensor with the same shape as the :obj:`forward_fn` would have given if applied.
|
||||
|
||||
|
||||
Examples::
|
||||
|
||||
Reference in New Issue
Block a user