From 3b44aa935a4d8f1b0e93a23070d97be6b9c9506b Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Fri, 24 Jul 2020 09:16:28 -0400 Subject: [PATCH] Model utils doc (#6005) * Document TF modeling utils * Document all model utils --- docs/source/index.rst | 3 +- docs/source/internal/modeling_utils.rst | 88 +++++ docs/source/main_classes/model.rst | 29 +- setup.cfg | 2 +- src/transformers/configuration_utils.py | 2 +- src/transformers/modeling_tf_utils.py | 267 +++++++++++---- src/transformers/modeling_utils.py | 429 ++++++++++++++++-------- 7 files changed, 601 insertions(+), 219 deletions(-) create mode 100644 docs/source/internal/modeling_utils.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index bcc46a01d2..c5eb3283b0 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -177,9 +177,9 @@ conversion utilities for the following models: main_classes/model main_classes/tokenizer main_classes/pipelines + main_classes/trainer main_classes/optimizer_schedules main_classes/processors - main_classes/trainer model_doc/auto model_doc/encoderdecoder model_doc/bert @@ -205,3 +205,4 @@ conversion utilities for the following models: model_doc/retribert model_doc/mobilebert model_doc/dpr + internal/modeling_utils diff --git a/docs/source/internal/modeling_utils.rst b/docs/source/internal/modeling_utils.rst new file mode 100644 index 0000000000..9e7fb6b11c --- /dev/null +++ b/docs/source/internal/modeling_utils.rst @@ -0,0 +1,88 @@ +Custom Layers and Utilities +--------------------------- + +This page lists all the custom layers used by the library, as well as the utility functions it provides for modeling. + +Most of those are only useful if you are studying the code of the models in the library. + + +``Pytorch custom modules`` +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_utils.Conv1D + +.. autoclass:: transformers.modeling_utils.PoolerStartLogits + :members: forward + +.. 
autoclass:: transformers.modeling_utils.PoolerEndLogits + :members: forward + +.. autoclass:: transformers.modeling_utils.PoolerAnswerClass + :members: forward + +.. autoclass:: transformers.modeling_utils.SquadHeadOutput + +.. autoclass:: transformers.modeling_utils.SQuADHead + :members: forward + +.. autoclass:: transformers.modeling_utils.SequenceSummary + :members: forward + + +``PyTorch Helper Functions`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: transformers.apply_chunking_to_forward + +.. autofunction:: transformers.modeling_utils.find_pruneable_heads_and_indices + +.. autofunction:: transformers.modeling_utils.prune_layer + +.. autofunction:: transformers.modeling_utils.prune_conv1d_layer + +.. autofunction:: transformers.modeling_utils.prune_linear_layer + +``TensorFlow custom layers`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_tf_utils.TFConv1D + +.. autoclass:: transformers.modeling_tf_utils.TFSharedEmbeddings + :members: call + +.. autoclass:: transformers.modeling_tf_utils.TFSequenceSummary + :members: call + + +``TensorFlow loss functions`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_tf_utils.TFCausalLanguageModelingLoss + :members: + +.. autoclass:: transformers.modeling_tf_utils.TFMaskedLanguageModelingLoss + :members: + +.. autoclass:: transformers.modeling_tf_utils.TFMultipleChoiceLoss + :members: + +.. autoclass:: transformers.modeling_tf_utils.TFQuestionAnsweringLoss + :members: + +.. autoclass:: transformers.modeling_tf_utils.TFSequenceClassificationLoss + :members: + +.. autoclass:: transformers.modeling_tf_utils.TFTokenClassificationLoss + :members: + + +``TensorFlow Helper Functions`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: transformers.modeling_tf_utils.cast_bool_to_primitive + +.. autofunction:: transformers.modeling_tf_utils.get_initializer + +.. autofunction:: transformers.modeling_tf_utils.keras_serializable + +.. 
autofunction:: transformers.modeling_tf_utils.shape_list \ No newline at end of file diff --git a/docs/source/main_classes/model.rst b/docs/source/main_classes/model.rst index d492b7713d..bea43e94f6 100644 --- a/docs/source/main_classes/model.rst +++ b/docs/source/main_classes/model.rst @@ -1,28 +1,43 @@ Models ---------------------------------------------------- -The base class :class:`~transformers.PreTrainedModel` implements the common methods for loading/saving a model either -from a local file or directory, or from a pretrained model configuration provided by the library (downloaded from -HuggingFace's AWS S3 repository). +The base classes :class:`~transformers.PreTrainedModel` and :class:`~transformers.TFPreTrainedModel` implement the +common methods for loading/saving a model either from a local file or directory, or from a pretrained model +configuration provided by the library (downloaded from HuggingFace's AWS S3 repository). -:class:`~transformers.PreTrainedModel` also implements a few methods which are common among all the models to: +:class:`~transformers.PreTrainedModel` and :class:`~transformers.TFPreTrainedModel` also implement a few methods which +are common among all the models to: - resize the input token embeddings when new tokens are added to the vocabulary - prune the attention heads of the model. +The other methods that are common to each model are defined in :class:`~transformers.modeling_utils.ModuleUtilsMixin` +(for the PyTorch models) and :class:`~transformers.modeling_tf_utils.TFModelUtilsMixin` (for the TensorFlow models). + + ``PreTrainedModel`` ~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.PreTrainedModel :members: -``Helper Functions`` -~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: transformers.apply_chunking_to_forward +``ModuleUtilsMixin`` +~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_utils.ModuleUtilsMixin + :members: + ``TFPreTrainedModel`` ~~~~~~~~~~~~~~~~~~~~~ .. 
autoclass:: transformers.TFPreTrainedModel :members: + + +``TFModelUtilsMixin`` +~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_tf_utils.TFModelUtilsMixin + :members: diff --git a/setup.cfg b/setup.cfg index d1e67228d2..d8272abd10 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,5 +43,5 @@ multi_line_output = 3 use_parentheses = True [flake8] -ignore = E203, E501, E741, W503 +ignore = E203, E501, E741, W503, W605 max-line-length = 119 diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 9a7b154cef..40efef2b3a 100644 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -100,7 +100,7 @@ class PretrainedConfig(object): method of the model. Parameters for fine-tuning tasks - - **architectures** (:obj:List[`str`], `optional`) -- Model architectures that can be used with the + - **architectures** (:obj:`List[str]`, `optional`) -- Model architectures that can be used with the model pretrained weights. - **finetuning_task** (:obj:`str`, `optional`) -- Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint. diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 8a4b8c95a7..b2ff04741d 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -18,7 +18,7 @@ import functools import logging import os import warnings -from typing import Dict +from typing import Dict, List, Optional, Union import h5py import numpy as np @@ -36,12 +36,19 @@ logger = logging.getLogger(__name__) class TFModelUtilsMixin: """ - A few utilities for `tf.keras.Model`s, to be used as a mixin. + A few utilities for :obj:`tf.keras.Model`, to be used as a mixin. """ def num_parameters(self, only_trainable: bool = False) -> int: """ - Get number of (optionally, trainable) parameters in the model. 
+ Get the number of (optionally, trainable) parameters in the model. + + Args: + only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to return only the number of trainable parameters + + Returns: + :obj:`int`: The number of parameters. """ if only_trainable: return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables)) @@ -54,16 +61,21 @@ def keras_serializable(cls): Decorate a Keras Layer class to support Keras serialization. This is done by: - 1. adding a `transformers_config` dict to the Keras config dictionary in `get_config` (called by Keras at - serialization time - 2. wrapping `__init__` to accept that `transformers_config` dict (passed by Keras at deserialization time) and - convert it to a config object for the actual layer initializer - 3. registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does - not need to be supplied in `custom_objects` in the call to `tf.keras.models.load_model` - :param cls: a tf.keras.layers.Layers subclass that accepts a `config` argument to its initializer (typically a - `TF*MainLayer` class in this project) - :return: the same class object, with modifications for Keras deserialization. + 1. Adding a :obj:`transformers_config` dict to the Keras config dictionary in :obj:`get_config` (called by Keras at + serialization time). + 2. Wrapping :obj:`__init__` to accept that :obj:`transformers_config` dict (passed by Keras at deserialization + time) and convert it to a config object for the actual layer initializer. + 3. Registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does + not need to be supplied in :obj:`custom_objects` in the call to :obj:`tf.keras.models.load_model`. + + Args: + cls (a :obj:`tf.keras.layers.Layer` subclass): + Typically a :obj:`TF.MainLayer` class in this project, in general must accept a :obj:`config` argument to + its initializer. 
+ + Returns: + The same class object, with modifications for Keras deserialization. """ initializer = cls.__init__ @@ -110,6 +122,15 @@ def keras_serializable(cls): class TFCausalLanguageModelingLoss: + """ + Loss function suitable for causal language modeling (CLM), that is, the task of guessing the next token. + + .. note:: + + Any label of -100 will be ignored (along with the corresponding logits) in the loss computation. + + """ + def compute_loss(self, labels, logits): loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( from_logits=True, reduction=tf.keras.losses.Reduction.NONE @@ -123,6 +144,10 @@ class TFCausalLanguageModelingLoss: class TFQuestionAnsweringLoss: + """ + Loss function suitable for question answering. + """ + def compute_loss(self, labels, logits): loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( from_logits=True, reduction=tf.keras.losses.Reduction.NONE @@ -134,6 +159,15 @@ class TFQuestionAnsweringLoss: class TFTokenClassificationLoss: + """ + Loss function suitable for token classification. + + .. note:: + + Any label of -100 will be ignored (along with the corresponding logits) in the loss computation. + + """ + def compute_loss(self, labels, logits): loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( from_logits=True, reduction=tf.keras.losses.Reduction.NONE @@ -141,7 +175,7 @@ class TFTokenClassificationLoss: # make sure only labels that are not equal to -100 # are taken into account as loss if tf.math.reduce_any(labels == -1).numpy() is True: - warnings.warn("Using `-1` to mask the loss for the token is depreciated. Please use `-100` instead.") + warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.") active_loss = tf.reshape(labels, (-1,)) != -1 else: active_loss = tf.reshape(labels, (-1,)) != -100 @@ -152,6 +186,10 @@ class TFTokenClassificationLoss: class TFSequenceClassificationLoss: + """ + Loss function suitable for sequence classification. 
+ """ + def compute_loss(self, labels, logits): if shape_list(logits)[1] == 1: loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE) @@ -163,8 +201,19 @@ class TFSequenceClassificationLoss: return loss_fn(labels, logits) -TFMultipleChoiceLoss = TFSequenceClassificationLoss -TFMaskedLanguageModelingLoss = TFCausalLanguageModelingLoss +class TFMultipleChoiceLoss(TFSequenceClassificationLoss): + """Loss function suitable for multiple choice tasks.""" + + +class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss): + """ + Loss function suitable for masked language modeling (MLM), that is, the task of guessing the masked tokens. + + .. note:: + + Any label of -100 will be ignored (along with the corresponding logits) in the loss computation. + +""" class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): @@ -347,7 +396,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): def save_pretrained(self, save_directory): """ Save a model and its configuration file to a directory, so that it can be re-loaded using the - `:func:`~transformers.TFPreTrainedModel.from_pretrained`` class method. + :func:`~transformers.TFPreTrainedModel.from_pretrained` class method. Arguments: save_directory (:obj:`str`): @@ -388,7 +437,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): ``dbmdz/bert-base-german-cased``. - A path to a `directory` containing model weights saved using :func:`~transformersTF.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. - - A path or url to a `PyTorch state_dict save file` (e.g, `./pt_model/pytorch_model.bin`). In + - A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided as ``config`` argument. 
This loading path is slower than converting the PyTorch model in a TensorFlow model using the provided conversion scripts and loading the TensorFlow model @@ -435,7 +484,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): Whether or not to only look at local files (e.g., not try doanloading the model). use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on - our S3 (faster). + our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB. kwargs (remaining dictionary of keyword arguments, `optional`): Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., :obj:`output_attention=True`). Behaves differently depending on whether a ``config`` is provided or @@ -611,10 +660,23 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): class TFConv1D(tf.keras.layers.Layer): + """ + 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). + + Basically works like a linear layer but the weights are transposed. + + Args: + nf (:obj:`int`): + The number of output features. + nx (:obj:`int`): + The number of input features. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation to use to initialize the weights. + kwargs: + Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`. + """ + def __init__(self, nf, nx, initializer_range=0.02, **kwargs): - """ TFConv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2) - Basically works like a Linear layer but the weights are transposed - """ super().__init__(**kwargs) self.nf = nf self.nx = nx @@ -638,10 +700,25 @@ class TFConv1D(tf.keras.layers.Layer): class TFSharedEmbeddings(tf.keras.layers.Layer): - """Construct shared token embeddings. 
+ """ + Construct shared token embeddings. + + The weights of the embedding layer are usually shared with the weights of the linear decoder when doing + language modeling. + + Args: + vocab_size (:obj:`int`): + The size of the vocabulary, i.e., the number of unique tokens. + hidden_size (:obj:`int`): + The size of the embedding vectors. + initializer_range (:obj:`float`, `optional`): + The standard deviation to use when initializing the weights. If no value is provided, it will default to + :math:`1/\sqrt{hidden\_size}`. + kwargs: + Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`. """ - def __init__(self, vocab_size, hidden_size, initializer_range=None, **kwargs): + def __init__(self, vocab_size: int, hidden_size: int, initializer_range: Optional[float] = None, **kwargs): super().__init__(**kwargs) self.vocab_size = vocab_size self.hidden_size = hidden_size @@ -667,20 +744,31 @@ class TFSharedEmbeddings(tf.keras.layers.Layer): return dict(list(base_config.items()) + list(config.items())) - def call(self, inputs, mode="embedding"): - """Get token embeddings of inputs. - Args: - inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids) - mode: string, a valid value is one of "embedding" and "linear". - Returns: - outputs: (1) If mode == "embedding", output embedding tensor, float32 with - shape [batch_size, length, embedding_size]; (2) mode == "linear", output - linear tensor, float32 with shape [batch_size, length, vocab_size]. - Raises: - ValueError: if mode is not valid. + def call(self, inputs: tf.Tensor, mode: str = "embedding") -> tf.Tensor: + """ + Get token embeddings of inputs or decode final hidden state. 
- Shared weights logic adapted from - https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 + Args: + inputs (:obj:`tf.Tensor`): + In embedding mode, should be an int64 tensor with shape :obj:`[batch_size, length]`. + + In linear mode, should be a float tensor with shape :obj:`[batch_size, length, hidden_size]`. + mode (:obj:`str`, defaults to :obj:`"embedding"`): + A valid value is either :obj:`"embedding"` or :obj:`"linear"`, the first one indicates that the layer + should be used as an embedding layer, the second one that the layer should be used as a linear decoder. + + Returns: + :obj:`tf.Tensor`: + In embedding mode, the output is a float32 embedding tensor, with shape + :obj:`[batch_size, length, embedding_size]`. + + In linear mode, the output is a float32 tensor with shape :obj:`[batch_size, length, vocab_size]`. + + Raises: + ValueError: if :obj:`mode` is not valid. + + Shared weights logic is adapted from + `here <https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24>`__. """ if mode == "embedding": return self._embedding(inputs) @@ -709,22 +797,38 @@ class TFSharedEmbeddings(tf.keras.layers.Layer): class TFSequenceSummary(tf.keras.layers.Layer): - r""" Compute a single vector summary of a sequence hidden states according to various possibilities: - Args of the config class: - summary_type: - - 'last' => [default] take the last token hidden state (like XLNet) - - 'first' => take the first token hidden state (like Bert) - - 'mean' => take the mean of all tokens hidden states - - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2) - - 'attn' => Not implemented now, use multi-head attention - summary_use_proj: Add a projection after the vector extraction - summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False. - summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. 
Default - summary_first_dropout: Add a dropout before the projection and activation - summary_last_dropout: Add a dropout after the projection and activation + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config (:class:`~transformers.PretrainedConfig`): + The config used by the model. Relevant arguments in the config class of the model are (refer to the + actual config class of your model for the default values it uses): + + - **summary_type** (:obj:`str`) -- The method to use to make this summary. Accepted values are: + + - :obj:`"last"` -- Take the last token hidden state (like XLNet) + - :obj:`"first"` -- Take the first token hidden state (like Bert) + - :obj:`"mean"` -- Take the mean of all tokens hidden states + - :obj:`"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - :obj:`"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to + :obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`). + - **summary_activation** (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the + output, another string or :obj:`None` will add no activation. + - **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and + activation. + - **summary_last_dropout** (:obj:`float`)-- Optional dropout probability after the projection and + activation. + + initializer_range (:obj:`float`, defaults to 0.02): The standard deviation to use to initialize the weights. + kwargs: + Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`. 
""" - def __init__(self, config, initializer_range=0.02, **kwargs): + def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **kwargs): super().__init__(**kwargs) self.summary_type = config.summary_type if hasattr(config, "summary_use_proj") else "last" @@ -756,12 +860,22 @@ class TFSequenceSummary(tf.keras.layers.Layer): if self.has_last_dropout: self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout) - def call(self, inputs, training=False): - """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer. - cls_index: [optional] position of the classification token if summary_type == 'cls_index', - shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states. - if summary_type == 'cls_index' and cls_index is None: - we take the last token of the sequence as classification token + def call(self, inputs, training=False) -> tf.Tensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + inputs (:obj:`Union[tf.Tensor, Tuple[tf.Tensor], List[tf.Tensor], Dict[str, tf.Tensor]]`): + One or two tensors representing: + + - **hidden_states** (:obj:`tf.Tensor` of shape :obj:`[batch_size, seq_len, hidden_size]`) -- The hidden + states of the last layer. + - **cls_index** :obj:`tf.Tensor` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ... are + optional leading dimensions of :obj:`hidden_states`. Used if :obj:`summary_type == "cls_index"` and + takes the last token of the sequence as classification token. + + Returns: + :obj:`tf.Tensor`: The summary of the sequence hidden states. """ if not isinstance(inputs, (dict, tuple, list)): hidden_states = inputs @@ -815,32 +929,47 @@ class TFSequenceSummary(tf.keras.layers.Layer): return output -def shape_list(x): - """Deal with dynamic shape in tensorflow cleanly.""" +def shape_list(x: tf.Tensor) -> List[int]: + """ + Deal with dynamic shape in tensorflow cleanly. 
+ + Args: + x (:obj:`tf.Tensor`): The tensor we want the shape of. + + Returns: + :obj:`List[int]`: The shape of the tensor as a list. + """ static = x.shape.as_list() dynamic = tf.shape(x) return [dynamic[i] if s is None else s for i, s in enumerate(static)] -def get_initializer(initializer_range=0.02): - """Creates a `tf.initializers.truncated_normal` with the given range. +def get_initializer(initializer_range: float = 0.02) -> tf.initializers.TruncatedNormal: + """ + Creates a :obj:`tf.initializers.TruncatedNormal` with the given range. + Args: - initializer_range: float, initializer range for stddev. + initializer_range (:obj:`float`, defaults to 0.02): Standard deviation of the initializer range. + Returns: - TruncatedNormal initializer with stddev = `initializer_range`. + :obj:`tf.initializers.TruncatedNormal`: The truncated normal initializer. """ return tf.keras.initializers.TruncatedNormal(stddev=initializer_range) -def cast_bool_to_primitive(bool_variable, default_tensor_to_true=False): - """Function arguments can be inserted as boolean tensor - and bool variables to cope with keras serialization - we need to cast `output_attentions` to correct bool - if it is a tensor +def cast_bool_to_primitive(bool_variable: Union[tf.Tensor, bool], default_tensor_to_true=False) -> bool: + """ + Function arguments can be inserted as boolean tensor and bool variables to cope with Keras serialization we need to + cast the bool arguments (like :obj:`output_attentions` for instance) to correct boolean if it is a tensor. Args: - default_tensor_to_true: bool, if tensor should default to True - in case tensor has no numpy attribute + bool_variable (:obj:`Union[tf.Tensor, bool]`): + The variable to convert to a boolean. + default_tensor_to_true (:obj:`bool`, `optional`, defaults to :obj:`False`): + The default value to use in case the tensor has no numpy attribute. + + Returns: + :obj:`bool`: The converted value. 
""" # if bool variable is tensor and has numpy value if tf.is_tensor(bool_variable): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 1850589cf6..bd33f7a7a3 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -19,7 +19,7 @@ import logging import os import re from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Set, Tuple, Union import torch from torch import Tensor, device, dtype, nn @@ -38,6 +38,7 @@ from .file_utils import ( hf_bucket_url, is_remote_url, is_torch_tpu_available, + replace_return_docstrings, ) from .generation_utils import GenerationMixin @@ -61,8 +62,20 @@ except ImportError: def find_pruneable_heads_and_indices( - heads: List, n_heads: int, head_size: int, already_pruned_heads: set -) -> Tuple[set, "torch.LongTensor"]: + heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int] +) -> Tuple[Set[int], torch.LongTensor]: + """ + Finds the heads and their indices taking :obj:`already_pruned_heads` into account. + + Args: + heads (:obj:`List[int]`): List of the indices of heads to prune. + n_heads (:obj:`int`): The number of heads in the model. + head_size (:obj:`int`): The size of each head. + already_pruned_heads (:obj:`Set[int]`): A set of already pruned heads. + + Returns: + :obj:`Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices. + """ mask = torch.ones(n_heads, head_size) heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads for head in heads: @@ -76,12 +89,19 @@ def find_pruneable_heads_and_indices( class ModuleUtilsMixin: """ - A few utilities for torch.nn.Modules, to be used as a mixin. + A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin. 
""" def num_parameters(self, only_trainable: bool = False) -> int: """ - Get number of (optionally, trainable) parameters in the module. + Get the number of (optionally, trainable) parameters in the model. + + Args: + only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to return only the number of trainable parameters + + Returns: + :obj:`int`: The number of parameters. """ params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters() return sum(p.numel() for p in params) @@ -113,8 +133,11 @@ class ModuleUtilsMixin: return None def add_memory_hooks(self): - """ Add a memory hook before and after each sub-module forward pass to record increase in memory consumption. - Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()` + """ + Add a memory hook before and after each sub-module forward pass to record increase in memory consumption. + + Increase in memory consumption is stored in a :obj:`mem_rss_diff` attribute for each module and can be reset to + zero with :obj:`model.reset_memory_hooks_state()`. """ for module in self.modules(): module.register_forward_pre_hook(self._hook_rss_memory_pre_forward) @@ -122,6 +145,10 @@ class ModuleUtilsMixin: self.reset_memory_hooks_state() def reset_memory_hooks_state(self): + """ + Reset the :obj:`mem_rss_diff` attribute of each module (see + :func:`~transformers.modeling_utils.ModuleUtilsMixin.add_memory_hooks`). + """ for module in self.modules(): module.mem_rss_diff = 0 module.mem_rss_post_forward = 0 @@ -130,7 +157,10 @@ class ModuleUtilsMixin: @property def device(self) -> device: """ - Get torch.device from module, assuming that the whole module has one device. + The device on which the module is (assuming that all the module parameters are on the same device). + + Returns: + :obj:`torch.device` The device of the module. 
""" try: return next(self.parameters()).device @@ -148,7 +178,10 @@ class ModuleUtilsMixin: @property def dtype(self) -> dtype: """ - Get torch.dtype from module, assuming that the whole module has one dtype. + The dtype of the module (assuming that all the module parameters have the same dtype). + + Returns: + :obj:`torch.dtype` The dtype of the module. """ try: return next(self.parameters()).dtype @@ -164,7 +197,15 @@ class ModuleUtilsMixin: return first_tuple[1].dtype def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor: - """type: torch.Tensor -> torch.Tensor""" + """ + Invert an attention mask (e.g., switches 0. and 1.). + + Args: + encoder_attention_mask (:obj:`torch.Tensor`): An attention mask. + + Returns: + :obj:`torch.Tensor`: The inverted attention mask. + """ if encoder_attention_mask.dim() == 3: encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] if encoder_attention_mask.dim() == 2: @@ -189,16 +230,20 @@ class ModuleUtilsMixin: return encoder_extended_attention_mask - def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple, device: device) -> Tensor: - """Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored. + def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. Arguments: - attention_mask: torch.Tensor with 1 indicating tokens to ATTEND to - input_shape: tuple, shape of input_ids - device: torch.Device, usually self.device + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. 
Returns: - torch.Tensor with dtype of attention_mask.dtype + :obj:`torch.Tensor`: The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`. """ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. @@ -233,17 +278,23 @@ class ModuleUtilsMixin: extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask - def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: bool = False) -> Tensor: + def get_head_mask( + self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False + ) -> Tensor: """ - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - attention_probs has shape bsz x n_heads x N x N - Arguments: - head_mask: torch.Tensor or None: has shape [num_heads] or [num_hidden_layers x num_heads] - num_hidden_layers: int + Prepare the head mask if needed. + + Args: + head_mask (:obj:`torch.Tensor` with shape :obj:`[num_heads]` or :obj:`[num_hidden_layers x num_heads]`, `optional`): + The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard). + num_hidden_layers (:obj:`int`): + The number of hidden layers in the model. + is_attention_chunked (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the attention scores are computed by chunks or not. + Returns: - Tensor of shape shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - or list with [None] for each layer + :obj:`torch.Tensor` with shape :obj:`[num_hidden_layers x batch x num_heads x seq_length x seq_length]` + or list with :obj:`[None]` for each layer. """ if head_mask is not None: head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) @@ -557,7 +608,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ``dbmdz/bert-base-german-cased``. 
- A path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. - - A path or url to a `tensorflow index checkpoint file` (e.g, `./tf_model/model.ckpt.index`). In + - A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. @@ -610,7 +661,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): Whether or not to only look at local files (e.g., not try doanloading the model). use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on - our S3 (faster). + our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB. kwargs (remaining dictionary of keyword arguments, `optional`): Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., :obj:`output_attention=True`). Behaves differently depending on whether a ``config`` is provided or @@ -870,10 +921,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): class Conv1D(nn.Module): + """ + 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). + + Basically works like a linear layer but the weights are transposed. + + Args: + nf (:obj:`int`): The number of output features. + nx (:obj:`int`): The number of input features. + """ + def __init__(self, nf, nx): - """ Conv1D layer as defined by Radford et al. 
for OpenAI GPT (and also used in GPT-2) - Basically works like a Linear layer but the weights are transposed - """ super().__init__() self.nf = nf w = torch.empty(nx, nf) @@ -889,17 +947,31 @@ class Conv1D(nn.Module): class PoolerStartLogits(nn.Module): - """ Compute SQuAD start_logits from sequence hidden states. """ + """ + Compute SQuAD start logits from sequence hidden states. - def __init__(self, config): + Args: + config (:class:`~transformers.PretrainedConfig`): + The config used by the model, will be used to grab the :obj:`hidden_size` of the model. + """ + + def __init__(self, config: PretrainedConfig): super().__init__() self.dense = nn.Linear(config.hidden_size, 1) - def forward(self, hidden_states, p_mask=None): - """ Args: - **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)` - invalid position mask such as query and special symbols (PAD, SEP, CLS) + def forward( + self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: + """ + Args: + hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token should be masked. + + Returns: + :obj:`torch.FloatTensor`: The start logits for SQuAD. """ x = self.dense(hidden_states).squeeze(-1) @@ -913,28 +985,48 @@ class PoolerStartLogits(nn.Module): class PoolerEndLogits(nn.Module): - """ Compute SQuAD end_logits from sequence hidden states and start token hidden state. + """ + Compute SQuAD end logits from sequence hidden states. + + Args: + config (:class:`~transformers.PretrainedConfig`): + The config used by the model, will be used to grab the :obj:`hidden_size` of the model and the + :obj:`layer_norm_eps` to use. 
""" - def __init__(self, config): + def __init__(self, config: PretrainedConfig): super().__init__() self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) self.activation = nn.Tanh() self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dense_1 = nn.Linear(config.hidden_size, 1) - def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None): - """ Args: - One of ``start_states``, ``start_positions`` should be not None. - If both are set, ``start_positions`` overrides ``start_states``. - - **start_states**: ``torch.LongTensor`` of shape identical to hidden_states - hidden states of the first tokens for the labeled span. - **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` - position of the first token for the labeled span: - **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)`` - Mask of invalid position such as query and special symbols (PAD, SEP, CLS) + def forward( + self, + hidden_states: torch.FloatTensor, + start_states: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + p_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Args: + hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + start_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`, `optional`): + The hidden states of the first tokens for the labeled span. + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + The position of the first token for the labeled span. + p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token should be masked. + + .. note:: + + One of ``start_states`` or ``start_positions`` should be not obj:`None`. 
If both are set, + ``start_positions`` overrides ``start_states``. + + Returns: + :obj:`torch.FloatTensor`: The end logits for SQuAD. """ assert ( start_states is not None or start_positions is not None @@ -960,7 +1052,13 @@ class PoolerEndLogits(nn.Module): class PoolerAnswerClass(nn.Module): - """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """ + """ + Compute SQuAD 2.0 answer class from classification and start tokens hidden states. + + Args: + config (:class:`~transformers.PretrainedConfig`): + The config used by the model, will be used to grab the :obj:`hidden_size` of the model. + """ def __init__(self, config): super().__init__() @@ -968,23 +1066,33 @@ class PoolerAnswerClass(nn.Module): self.activation = nn.Tanh() self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False) - def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None): + def forward( + self, + hidden_states: torch.FloatTensor, + start_states: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + cls_index: Optional[torch.LongTensor] = None, + ) -> torch.FloatTensor: """ Args: - One of ``start_states``, ``start_positions`` should be not None. - If both are set, ``start_positions`` overrides ``start_states``. + hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + start_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`, `optional`): + The hidden states of the first tokens for the labeled span. + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + The position of the first token for the labeled span. + cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token. 
- **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``. - hidden states of the first tokens for the labeled span. - **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` - position of the first token for the labeled span. - **cls_index**: torch.LongTensor of shape ``(batch_size,)`` - position of the CLS token. If None, take the last token. + .. note:: - note(Original repo): - no dependency on end_feature so that we can obtain one single `cls_logits` - for each sample + One of ``start_states`` or ``start_positions`` should not be :obj:`None`. If both are set, + ``start_positions`` overrides ``start_states``. + + Returns: + :obj:`torch.FloatTensor`: The SQuAD 2.0 answer class. """ + # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample. hsz = hidden_states.shape[-1] assert ( start_states is not None or start_positions is not None @@ -1009,7 +1117,7 @@ class PoolerAnswerClass(nn.Module): @dataclass class SquadHeadOutput(ModelOutput): """ - Base class for outputs of question answering models using a :obj:`SquadHead`. + Base class for outputs of question answering models using a :class:`~transformers.modeling_utils.SQuADHead`. Args: loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned if both :obj:`start_positions` and :obj:`end_positions` are provided): @@ -1036,44 +1144,13 @@ class SquadHeadOutput(ModelOutput): class SQuADHead(nn.Module): - r""" A SQuAD head inspired by XLNet. + r""" + A SQuAD head inspired by XLNet. - Parameters: - config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model. - - Inputs: - **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)`` - hidden states of sequence tokens - **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` - position of the first token for the labeled span.
- **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` - position of the last token for the labeled span. - **cls_index**: torch.LongTensor of shape ``(batch_size,)`` - position of the CLS token. If None, take the last token. - **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)`` - Whether the question has a possible answer in the paragraph or not. - **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)`` - Mask of invalid position such as query and special symbols (PAD, SEP, CLS) - 1.0 means token should be masked. - - Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: - **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``: - Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses. - **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) - ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)`` - Log probabilities for the top config.start_n_top start token possibilities (beam-search). - **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) - ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)`` - Indices for the top config.start_n_top start token possibilities (beam-search). - **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) - ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` - Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). 
- **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) - ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` - Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). - **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) - ``torch.FloatTensor`` of shape ``(batch_size,)`` - Log probabilities for the ``is_impossible`` label of the answers. + Args: + config (:class:`~transformers.PretrainedConfig`): + The config used by the model, will be used to grab the :obj:`hidden_size` of the model and the + :obj:`layer_norm_eps` to use. """ def __init__(self, config): @@ -1085,16 +1162,37 @@ class SQuADHead(nn.Module): self.end_logits = PoolerEndLogits(config) self.answer_class = PoolerAnswerClass(config) + @replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig) def forward( self, - hidden_states, - start_positions=None, - end_positions=None, - cls_index=None, - is_impossible=None, - p_mask=None, - return_tuple=False, - ): + hidden_states: torch.FloatTensor, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + cls_index: Optional[torch.LongTensor] = None, + is_impossible: Optional[torch.LongTensor] = None, + p_mask: Optional[torch.FloatTensor] = None, + return_tuple: bool = False, + ) -> Union[SquadHeadOutput, Tuple[torch.FloatTensor]]: + """ + Args: + hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`): + Final hidden states of the model on the sequence tokens. + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Positions of the first token for the labeled span. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Positions of the last token for the labeled span. 
+ cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token. + is_impossible (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Whether the question has a possible answer in the paragraph or not. + p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). + 1.0 means token should be masked. + return_tuple (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to return a plain tuple instead of a :class:`~transformers.file_utils.ModelOutput`. + + Returns: + """ start_logits = self.start_logits(hidden_states, p_mask=p_mask) if start_positions is not None and end_positions is not None: @@ -1163,19 +1261,31 @@ class SQuADHead(nn.Module): class SequenceSummary(nn.Module): - r""" Compute a single vector summary of a sequence hidden states according to various possibilities: - Args of the config class: - summary_type: - - 'last' => [default] take the last token hidden state (like XLNet) - - 'first' => take the first token hidden state (like Bert) - - 'mean' => take the mean of all tokens hidden states - - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2) - - 'attn' => Not implemented now, use multi-head attention - summary_use_proj: Add a projection after the vector extraction - summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False. - summary_activation: 'tanh' or another string => add an activation to the output, Other => no activation. Default - summary_first_dropout: Add a dropout before the projection and activation - summary_last_dropout: Add a dropout after the projection and activation + r""" + Compute a single vector summary of a sequence hidden states.
+ + Args: + config (:class:`~transformers.PretrainedConfig`): + The config used by the model. Relevant arguments in the config class of the model are (refer to the + actual config class of your model for the default values it uses): + + - **summary_type** (:obj:`str`) -- The method to use to make this summary. Accepted values are: + + - :obj:`"last"` -- Take the last token hidden state (like XLNet) + - :obj:`"first"` -- Take the first token hidden state (like Bert) + - :obj:`"mean"` -- Take the mean of all tokens hidden states + - :obj:`"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - :obj:`"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to + :obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`). + - **summary_activation** (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the + output, another string or :obj:`None` will add no activation. + - **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and + activation. + - **summary_last_dropout** (:obj:`float`) -- Optional dropout probability after the projection and + activation. """ def __init__(self, config: PretrainedConfig): @@ -1207,12 +1317,21 @@ class SequenceSummary(nn.Module): if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: self.last_dropout = nn.Dropout(config.summary_last_dropout) - def forward(self, hidden_states, cls_index=None): - """ hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer. - cls_index: [optional] position of the classification token if summary_type == 'cls_index', - shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
- if summary_type == 'cls_index' and cls_index is None: - we take the last token of the sequence as classification token + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (:obj:`torch.FloatTensor` of shape :obj:`[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (:obj:`torch.LongTensor` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ... are optional leading dimensions of :obj:`hidden_states`, `optional`): + Used if :obj:`summary_type == "cls_index"` and takes the last token of the sequence as classification + token. + + Returns: + :obj:`torch.FloatTensor`: The summary of the sequence hidden states. """ if self.summary_type == "last": output = hidden_states[:, -1] @@ -1239,10 +1358,19 @@ class SequenceSummary(nn.Module): return output -def prune_linear_layer(layer, index, dim=0): - """ Prune a linear layer (a model parameters) to keep only entries in index. - Return the pruned layer as a new layer with requires_grad=True. - Used to remove heads. +def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0) -> torch.nn.Linear: + """ + Prune a linear layer to keep only entries in index. + + Used to remove heads. + + Args: + layer (:obj:`torch.nn.Linear`): The layer to prune. + index (:obj:`torch.LongTensor`): The indices to keep in the layer. + dim (:obj:`int`, `optional`, defaults to 0): The dimension on which to keep the indices. + + Returns: + :obj:`torch.nn.Linear`: The pruned layer as a new layer with :obj:`requires_grad=True`. 
""" index = index.to(layer.weight.device) W = layer.weight.index_select(dim, index).clone().detach() @@ -1264,11 +1392,20 @@ def prune_linear_layer(layer, index, dim=0): return new_layer -def prune_conv1d_layer(layer, index, dim=1): - """ Prune a Conv1D layer (a model parameters) to keep only entries in index. - A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed. - Return the pruned layer as a new layer with requires_grad=True. - Used to remove heads. +def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D: + """ + Prune a Conv1D layer to keep only entries in index. A Conv1D work as a Linear layer (see e.g. BERT) but the weights + are transposed. + + Used to remove heads. + + Args: + layer (:class:`~transformers.modeling_utils.Conv1D`): The layer to prune. + index (:obj:`torch.LongTensor`): The indices to keep in the layer. + dim (:obj:`int`, `optional`, defaults to 1): The dimension on which to keep the indices. + + Returns: + :class:`~transformers.modeling_utils.Conv1D`: The pruned layer as a new layer with :obj:`requires_grad=True`. """ index = index.to(layer.weight.device) W = layer.weight.index_select(dim, index).clone().detach() @@ -1288,10 +1425,22 @@ def prune_conv1d_layer(layer, index, dim=1): return new_layer -def prune_layer(layer, index, dim=None): - """ Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index. - Return the pruned layer as a new layer with requires_grad=True. - Used to remove heads. +def prune_layer( + layer: Union[torch.nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None +) -> Union[torch.nn.Linear, Conv1D]: + """ + Prune a Conv1D or linear layer to keep only entries in index. + + Used to remove heads. + + Args: + layer (:obj:`Union[torch.nn.Linear, Conv1D]`): The layer to prune. + index (:obj:`torch.LongTensor`): The indices to keep in the layer. + dim (:obj:`int`, `optional`): The dimension on which to keep the indices. 
+ + Returns: + :obj:`torch.nn.Linear` or :class:`~transformers.modeling_utils.Conv1D`: + The pruned layer as a new layer with :obj:`requires_grad=True`. """ if isinstance(layer, nn.Linear): return prune_linear_layer(layer, index, dim=0 if dim is None else dim)