Model utils doc (#6005)

* Document TF modeling utils

* Document all model utils
This commit is contained in:
Sylvain Gugger
2020-07-24 09:16:28 -04:00
committed by GitHub
parent a540405213
commit 3b44aa935a
7 changed files with 601 additions and 219 deletions

View File

@@ -19,7 +19,7 @@ import logging
import os
import re
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import torch
from torch import Tensor, device, dtype, nn
@@ -38,6 +38,7 @@ from .file_utils import (
hf_bucket_url,
is_remote_url,
is_torch_tpu_available,
replace_return_docstrings,
)
from .generation_utils import GenerationMixin
@@ -61,8 +62,20 @@ except ImportError:
def find_pruneable_heads_and_indices(
heads: List, n_heads: int, head_size: int, already_pruned_heads: set
) -> Tuple[set, "torch.LongTensor"]:
heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
) -> Tuple[Set[int], torch.LongTensor]:
"""
Finds the heads and their indices taking :obj:`already_pruned_heads` into account.
Args:
heads (:obj:`List[int]`): List of the indices of heads to prune.
n_heads (:obj:`int`): The number of heads in the model.
head_size (:obj:`int`): The size of each head.
already_pruned_heads (:obj:`Set[int]`): A set of already pruned heads.
Returns:
:obj:`Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
"""
mask = torch.ones(n_heads, head_size)
heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
for head in heads:
@@ -76,12 +89,19 @@ def find_pruneable_heads_and_indices(
class ModuleUtilsMixin:
"""
A few utilities for torch.nn.Modules, to be used as a mixin.
A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
"""
def num_parameters(self, only_trainable: bool = False) -> int:
"""
Get number of (optionally, trainable) parameters in the module.
Get the number of (optionally, trainable) parameters in the model.
Args:
only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return only the number of trainable parameters.
Returns:
:obj:`int`: The number of parameters.
"""
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
return sum(p.numel() for p in params)
@@ -113,8 +133,11 @@ class ModuleUtilsMixin:
return None
def add_memory_hooks(self):
""" Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()`
"""
Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
Increase in memory consumption is stored in a :obj:`mem_rss_diff` attribute for each module and can be reset to
zero with :obj:`model.reset_memory_hooks_state()`.
"""
for module in self.modules():
module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
@@ -122,6 +145,10 @@ class ModuleUtilsMixin:
self.reset_memory_hooks_state()
def reset_memory_hooks_state(self):
"""
Reset the :obj:`mem_rss_diff` attribute of each module (see
:func:`~transformers.modeling_utils.ModuleUtilsMixin.add_memory_hooks`).
"""
for module in self.modules():
module.mem_rss_diff = 0
module.mem_rss_post_forward = 0
@@ -130,7 +157,10 @@ class ModuleUtilsMixin:
@property
def device(self) -> device:
"""
Get torch.device from module, assuming that the whole module has one device.
The device on which the module is (assuming that all the module parameters are on the same device).
Returns:
:obj:`torch.device`: The device of the module.
"""
try:
return next(self.parameters()).device
@@ -148,7 +178,10 @@ class ModuleUtilsMixin:
@property
def dtype(self) -> dtype:
"""
Get torch.dtype from module, assuming that the whole module has one dtype.
The dtype of the module (assuming that all the module parameters have the same dtype).
Returns:
:obj:`torch.dtype`: The dtype of the module.
"""
try:
return next(self.parameters()).dtype
@@ -164,7 +197,15 @@ class ModuleUtilsMixin:
return first_tuple[1].dtype
def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
"""type: torch.Tensor -> torch.Tensor"""
"""
Invert an attention mask (e.g., switches 0. and 1.).
Args:
encoder_attention_mask (:obj:`torch.Tensor`): An attention mask.
Returns:
:obj:`torch.Tensor`: The inverted attention mask.
"""
if encoder_attention_mask.dim() == 3:
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
if encoder_attention_mask.dim() == 2:
@@ -189,16 +230,20 @@ class ModuleUtilsMixin:
return encoder_extended_attention_mask
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple, device: device) -> Tensor:
"""Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored.
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Arguments:
attention_mask: torch.Tensor with 1 indicating tokens to ATTEND to
input_shape: tuple, shape of input_ids
device: torch.Device, usually self.device
attention_mask (:obj:`torch.Tensor`):
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (:obj:`Tuple[int]`):
The shape of the input to the model.
device (:obj:`torch.device`):
The device of the input to the model.
Returns:
torch.Tensor with dtype of attention_mask.dtype
:obj:`torch.Tensor`: The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
"""
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
@@ -233,17 +278,23 @@ class ModuleUtilsMixin:
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: bool = False) -> Tensor:
def get_head_mask(
self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
) -> Tensor:
"""
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
attention_probs has shape bsz x n_heads x N x N
Arguments:
head_mask: torch.Tensor or None: has shape [num_heads] or [num_hidden_layers x num_heads]
num_hidden_layers: int
Prepare the head mask if needed.
Args:
head_mask (:obj:`torch.Tensor` with shape :obj:`[num_heads]` or :obj:`[num_hidden_layers x num_heads]`, `optional`):
The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
num_hidden_layers (:obj:`int`):
The number of hidden layers in the model.
is_attention_chunked (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the attention scores are computed by chunks.
Returns:
Tensor of shape shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
or list with [None] for each layer
:obj:`torch.Tensor` with shape :obj:`[num_hidden_layers x batch x num_heads x seq_length x seq_length]`
or list with :obj:`[None]` for each layer.
"""
if head_mask is not None:
head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
@@ -557,7 +608,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `tensorflow index checkpoint file` (e.g, `./tf_model/model.ckpt.index`). In
- A path or url to a `tensorflow index checkpoint file` (e.g., ``./tf_model/model.ckpt.index``). In
this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in
a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
@@ -610,7 +661,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
Whether or not to only look at local files (e.g., not try downloading the model).
use_cdn (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
our S3 (faster).
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
kwargs (remaining dictionary of keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attention=True`). Behaves differently depending on whether a ``config`` is provided or
@@ -870,10 +921,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
class Conv1D(nn.Module):
"""
1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
Basically works like a linear layer but the weights are transposed.
Args:
nf (:obj:`int`): The number of output features.
nx (:obj:`int`): The number of input features.
"""
def __init__(self, nf, nx):
""" Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
Basically works like a Linear layer but the weights are transposed
"""
super().__init__()
self.nf = nf
w = torch.empty(nx, nf)
@@ -889,17 +947,31 @@ class Conv1D(nn.Module):
class PoolerStartLogits(nn.Module):
""" Compute SQuAD start_logits from sequence hidden states. """
"""
Compute SQuAD start logits from sequence hidden states.
def __init__(self, config):
Args:
config (:class:`~transformers.PretrainedConfig`):
The config used by the model, will be used to grab the :obj:`hidden_size` of the model.
"""
def __init__(self, config: PretrainedConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, 1)
def forward(self, hidden_states, p_mask=None):
""" Args:
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)`
invalid position mask such as query and special symbols (PAD, SEP, CLS)
def forward(
self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
) -> torch.FloatTensor:
"""
Args:
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
The final hidden states of the model.
p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS).
1.0 means token should be masked.
Returns:
:obj:`torch.FloatTensor`: The start logits for SQuAD.
"""
x = self.dense(hidden_states).squeeze(-1)
@@ -913,28 +985,48 @@ class PoolerStartLogits(nn.Module):
class PoolerEndLogits(nn.Module):
""" Compute SQuAD end_logits from sequence hidden states and start token hidden state.
"""
Compute SQuAD end logits from sequence hidden states.
Args:
config (:class:`~transformers.PretrainedConfig`):
The config used by the model, will be used to grab the :obj:`hidden_size` of the model and the
:obj:`layer_norm_eps` to use.
"""
def __init__(self, config):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
self.activation = nn.Tanh()
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dense_1 = nn.Linear(config.hidden_size, 1)
def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
""" Args:
One of ``start_states``, ``start_positions`` should be not None.
If both are set, ``start_positions`` overrides ``start_states``.
**start_states**: ``torch.LongTensor`` of shape identical to hidden_states
hidden states of the first tokens for the labeled span.
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
position of the first token for the labeled span:
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
def forward(
self,
hidden_states: torch.FloatTensor,
start_states: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
p_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
Args:
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
The final hidden states of the model.
start_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`, `optional`):
The hidden states of the first tokens for the labeled span.
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
The position of the first token for the labeled span.
p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS).
1.0 means token should be masked.
.. note::
One of ``start_states`` or ``start_positions`` should not be :obj:`None`. If both are set,
``start_positions`` overrides ``start_states``.
Returns:
:obj:`torch.FloatTensor`: The end logits for SQuAD.
"""
assert (
start_states is not None or start_positions is not None
@@ -960,7 +1052,13 @@ class PoolerEndLogits(nn.Module):
class PoolerAnswerClass(nn.Module):
""" Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
"""
Compute SQuAD 2.0 answer class from classification and start tokens hidden states.
Args:
config (:class:`~transformers.PretrainedConfig`):
The config used by the model, will be used to grab the :obj:`hidden_size` of the model.
"""
def __init__(self, config):
super().__init__()
@@ -968,23 +1066,33 @@ class PoolerAnswerClass(nn.Module):
self.activation = nn.Tanh()
self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
def forward(
self,
hidden_states: torch.FloatTensor,
start_states: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
cls_index: Optional[torch.LongTensor] = None,
) -> torch.FloatTensor:
"""
Args:
One of ``start_states``, ``start_positions`` should be not None.
If both are set, ``start_positions`` overrides ``start_states``.
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
The final hidden states of the model.
start_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`, `optional`):
The hidden states of the first tokens for the labeled span.
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
The position of the first token for the labeled span.
cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token.
**start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``.
hidden states of the first tokens for the labeled span.
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
position of the first token for the labeled span.
**cls_index**: torch.LongTensor of shape ``(batch_size,)``
position of the CLS token. If None, take the last token.
.. note::
note(Original repo):
no dependency on end_feature so that we can obtain one single `cls_logits`
for each sample
One of ``start_states`` or ``start_positions`` should not be :obj:`None`. If both are set,
``start_positions`` overrides ``start_states``.
Returns:
:obj:`torch.FloatTensor`: The SQuAD 2.0 answer class.
"""
# No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
hsz = hidden_states.shape[-1]
assert (
start_states is not None or start_positions is not None
@@ -1009,7 +1117,7 @@ class PoolerAnswerClass(nn.Module):
@dataclass
class SquadHeadOutput(ModelOutput):
"""
Base class for outputs of question answering models using a :obj:`SquadHead`.
Base class for outputs of question answering models using a :class:`~transformers.modeling_utils.SQuADHead`.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned if both :obj:`start_positions` and :obj:`end_positions` are provided):
@@ -1036,44 +1144,13 @@ class SquadHeadOutput(ModelOutput):
class SQuADHead(nn.Module):
r""" A SQuAD head inspired by XLNet.
r"""
A SQuAD head inspired by XLNet.
Parameters:
config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
Inputs:
**hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
hidden states of sequence tokens
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
position of the first token for the labeled span.
**end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
position of the last token for the labeled span.
**cls_index**: torch.LongTensor of shape ``(batch_size,)``
position of the CLS token. If None, take the last token.
**is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
Whether the question has a possible answer in the paragraph or not.
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
1.0 means token should be masked.
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
**start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
Log probabilities for the top config.start_n_top start token possibilities (beam-search).
**start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
Indices for the top config.start_n_top start token possibilities (beam-search).
**end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
**end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
**cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
``torch.FloatTensor`` of shape ``(batch_size,)``
Log probabilities for the ``is_impossible`` label of the answers.
Args:
config (:class:`~transformers.PretrainedConfig`):
The config used by the model, will be used to grab the :obj:`hidden_size` of the model and the
:obj:`layer_norm_eps` to use.
"""
def __init__(self, config):
@@ -1085,16 +1162,37 @@ class SQuADHead(nn.Module):
self.end_logits = PoolerEndLogits(config)
self.answer_class = PoolerAnswerClass(config)
@replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig)
def forward(
self,
hidden_states,
start_positions=None,
end_positions=None,
cls_index=None,
is_impossible=None,
p_mask=None,
return_tuple=False,
):
hidden_states: torch.FloatTensor,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
cls_index: Optional[torch.LongTensor] = None,
is_impossible: Optional[torch.LongTensor] = None,
p_mask: Optional[torch.FloatTensor] = None,
return_tuple: bool = False,
) -> Union[SquadHeadOutput, Tuple[torch.FloatTensor]]:
"""
Args:
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
Final hidden states of the model on the sequence tokens.
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Positions of the first token for the labeled span.
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Positions of the last token for the labeled span.
cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token.
is_impossible (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Whether the question has a possible answer in the paragraph or not.
p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS).
1.0 means token should be masked.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return a plain tuple instead of a :class:`~transformers.file_utils.ModelOutput`.
Returns:
"""
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
if start_positions is not None and end_positions is not None:
@@ -1163,19 +1261,31 @@ class SQuADHead(nn.Module):
class SequenceSummary(nn.Module):
r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
Args of the config class:
summary_type:
- 'last' => [default] take the last token hidden state (like XLNet)
- 'first' => take the first token hidden state (like Bert)
- 'mean' => take the mean of all tokens hidden states
- 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
- 'attn' => Not implemented now, use multi-head attention
summary_use_proj: Add a projection after the vector extraction
summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
summary_activation: 'tanh' or another string => add an activation to the output, Other => no activation. Default
summary_first_dropout: Add a dropout before the projection and activation
summary_last_dropout: Add a dropout after the projection and activation
r"""
Compute a single vector summary of a sequence hidden states.
Args:
config (:class:`~transformers.PretrainedConfig`):
The config used by the model. Relevant arguments in the config class of the model are (refer to the
actual config class of your model for the default values it uses):
- **summary_type** (:obj:`str`) -- The method to use to make this summary. Accepted values are:
- :obj:`"last"` -- Take the last token hidden state (like XLNet)
- :obj:`"first"` -- Take the first token hidden state (like Bert)
- :obj:`"mean"` -- Take the mean of all tokens hidden states
- :obj:`"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
- :obj:`"attn"` -- Not implemented now, use multi-head attention
- **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction.
- **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to
:obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`).
- **summary_activation** (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the
output, another string or :obj:`None` will add no activation.
- **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and
activation.
- **summary_last_dropout** (:obj:`float`) -- Optional dropout probability after the projection and
activation.
"""
def __init__(self, config: PretrainedConfig):
@@ -1207,12 +1317,21 @@ class SequenceSummary(nn.Module):
if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
self.last_dropout = nn.Dropout(config.summary_last_dropout)
def forward(self, hidden_states, cls_index=None):
""" hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
cls_index: [optional] position of the classification token if summary_type == 'cls_index',
shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
if summary_type == 'cls_index' and cls_index is None:
we take the last token of the sequence as classification token
def forward(
self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
) -> torch.FloatTensor:
"""
Compute a single vector summary of a sequence hidden states.
Args:
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`[batch_size, seq_len, hidden_size]`):
The hidden states of the last layer.
cls_index (:obj:`torch.LongTensor` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ... are optional leading dimensions of :obj:`hidden_states`, `optional`):
Used if :obj:`summary_type == "cls_index"` and takes the last token of the sequence as classification
token.
Returns:
:obj:`torch.FloatTensor`: The summary of the sequence hidden states.
"""
if self.summary_type == "last":
output = hidden_states[:, -1]
@@ -1239,10 +1358,19 @@ class SequenceSummary(nn.Module):
return output
def prune_linear_layer(layer, index, dim=0):
""" Prune a linear layer (a model parameters) to keep only entries in index.
Return the pruned layer as a new layer with requires_grad=True.
Used to remove heads.
def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0) -> torch.nn.Linear:
"""
Prune a linear layer to keep only entries in index.
Used to remove heads.
Args:
layer (:obj:`torch.nn.Linear`): The layer to prune.
index (:obj:`torch.LongTensor`): The indices to keep in the layer.
dim (:obj:`int`, `optional`, defaults to 0): The dimension on which to keep the indices.
Returns:
:obj:`torch.nn.Linear`: The pruned layer as a new layer with :obj:`requires_grad=True`.
"""
index = index.to(layer.weight.device)
W = layer.weight.index_select(dim, index).clone().detach()
@@ -1264,11 +1392,20 @@ def prune_linear_layer(layer, index, dim=0):
return new_layer
def prune_conv1d_layer(layer, index, dim=1):
""" Prune a Conv1D layer (a model parameters) to keep only entries in index.
A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed.
Return the pruned layer as a new layer with requires_grad=True.
Used to remove heads.
def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
"""
Prune a Conv1D layer to keep only entries in index. A Conv1D works as a Linear layer (see e.g. BERT) but the weights
are transposed.
Used to remove heads.
Args:
layer (:class:`~transformers.modeling_utils.Conv1D`): The layer to prune.
index (:obj:`torch.LongTensor`): The indices to keep in the layer.
dim (:obj:`int`, `optional`, defaults to 1): The dimension on which to keep the indices.
Returns:
:class:`~transformers.modeling_utils.Conv1D`: The pruned layer as a new layer with :obj:`requires_grad=True`.
"""
index = index.to(layer.weight.device)
W = layer.weight.index_select(dim, index).clone().detach()
@@ -1288,10 +1425,22 @@ def prune_conv1d_layer(layer, index, dim=1):
return new_layer
def prune_layer(layer, index, dim=None):
""" Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index.
Return the pruned layer as a new layer with requires_grad=True.
Used to remove heads.
def prune_layer(
layer: Union[torch.nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None
) -> Union[torch.nn.Linear, Conv1D]:
"""
Prune a Conv1D or linear layer to keep only entries in index.
Used to remove heads.
Args:
layer (:obj:`Union[torch.nn.Linear, Conv1D]`): The layer to prune.
index (:obj:`torch.LongTensor`): The indices to keep in the layer.
dim (:obj:`int`, `optional`): The dimension on which to keep the indices.
Returns:
:obj:`torch.nn.Linear` or :class:`~transformers.modeling_utils.Conv1D`:
The pruned layer as a new layer with :obj:`requires_grad=True`.
"""
if isinstance(layer, nn.Linear):
return prune_linear_layer(layer, index, dim=0 if dim is None else dim)