Model utils doc (#6005)
* Document TF modeling utils * Document all model utils
This commit is contained in:
@@ -19,7 +19,7 @@ import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor, device, dtype, nn
|
||||
@@ -38,6 +38,7 @@ from .file_utils import (
|
||||
hf_bucket_url,
|
||||
is_remote_url,
|
||||
is_torch_tpu_available,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .generation_utils import GenerationMixin
|
||||
|
||||
@@ -61,8 +62,20 @@ except ImportError:
|
||||
|
||||
|
||||
def find_pruneable_heads_and_indices(
|
||||
heads: List, n_heads: int, head_size: int, already_pruned_heads: set
|
||||
) -> Tuple[set, "torch.LongTensor"]:
|
||||
heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
|
||||
) -> Tuple[Set[int], torch.LongTensor]:
|
||||
"""
|
||||
Finds the heads and their indices taking :obj:`already_pruned_heads` into account.
|
||||
|
||||
Args:
|
||||
heads (:obj:`List[int]`): List of the indices of heads to prune.
|
||||
n_heads (:obj:`int`): The number of heads in the model.
|
||||
head_size (:obj:`int`): The size of each head.
|
||||
already_pruned_heads (:obj:`Set[int]`): A set of already pruned heads.
|
||||
|
||||
Returns:
|
||||
:obj:`Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
|
||||
"""
|
||||
mask = torch.ones(n_heads, head_size)
|
||||
heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
|
||||
for head in heads:
|
||||
@@ -76,12 +89,19 @@ def find_pruneable_heads_and_indices(
|
||||
|
||||
class ModuleUtilsMixin:
|
||||
"""
|
||||
A few utilities for torch.nn.Modules, to be used as a mixin.
|
||||
A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
|
||||
"""
|
||||
|
||||
def num_parameters(self, only_trainable: bool = False) -> int:
|
||||
"""
|
||||
Get number of (optionally, trainable) parameters in the module.
|
||||
Get the number of (optionally, trainable) parameters in the model.
|
||||
|
||||
Args:
|
||||
only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to return only the number of trainable parameters
|
||||
|
||||
Returns:
|
||||
:obj:`int`: The number of parameters.
|
||||
"""
|
||||
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
|
||||
return sum(p.numel() for p in params)
|
||||
@@ -113,8 +133,11 @@ class ModuleUtilsMixin:
|
||||
return None
|
||||
|
||||
def add_memory_hooks(self):
|
||||
""" Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
|
||||
Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()`
|
||||
"""
|
||||
Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
|
||||
|
||||
Increase in memory consumption is stored in a :obj:`mem_rss_diff` attribute for each module and can be reset to
|
||||
zero with :obj:`model.reset_memory_hooks_state()`.
|
||||
"""
|
||||
for module in self.modules():
|
||||
module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
|
||||
@@ -122,6 +145,10 @@ class ModuleUtilsMixin:
|
||||
self.reset_memory_hooks_state()
|
||||
|
||||
def reset_memory_hooks_state(self):
|
||||
"""
|
||||
Reset the :obj:`mem_rss_diff` attribute of each module (see
|
||||
:func:`~transformers.modeling_utils.ModuleUtilsMixin.add_memory_hooks`).
|
||||
"""
|
||||
for module in self.modules():
|
||||
module.mem_rss_diff = 0
|
||||
module.mem_rss_post_forward = 0
|
||||
@@ -130,7 +157,10 @@ class ModuleUtilsMixin:
|
||||
@property
|
||||
def device(self) -> device:
|
||||
"""
|
||||
Get torch.device from module, assuming that the whole module has one device.
|
||||
The device on which the module is (assuming that all the module parameters are on the same device).
|
||||
|
||||
Returns:
|
||||
:obj:`torch.device`: The device of the module.
|
||||
"""
|
||||
try:
|
||||
return next(self.parameters()).device
|
||||
@@ -148,7 +178,10 @@ class ModuleUtilsMixin:
|
||||
@property
|
||||
def dtype(self) -> dtype:
|
||||
"""
|
||||
Get torch.dtype from module, assuming that the whole module has one dtype.
|
||||
The dtype of the module (assuming that all the module parameters have the same dtype).
|
||||
|
||||
Returns:
|
||||
:obj:`torch.dtype`: The dtype of the module.
|
||||
"""
|
||||
try:
|
||||
return next(self.parameters()).dtype
|
||||
@@ -164,7 +197,15 @@ class ModuleUtilsMixin:
|
||||
return first_tuple[1].dtype
|
||||
|
||||
def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
|
||||
"""type: torch.Tensor -> torch.Tensor"""
|
||||
"""
|
||||
Invert an attention mask (e.g., switches 0. and 1.).
|
||||
|
||||
Args:
|
||||
encoder_attention_mask (:obj:`torch.Tensor`): An attention mask.
|
||||
|
||||
Returns:
|
||||
:obj:`torch.Tensor`: The inverted attention mask.
|
||||
"""
|
||||
if encoder_attention_mask.dim() == 3:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
|
||||
if encoder_attention_mask.dim() == 2:
|
||||
@@ -189,16 +230,20 @@ class ModuleUtilsMixin:
|
||||
|
||||
return encoder_extended_attention_mask
|
||||
|
||||
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple, device: device) -> Tensor:
|
||||
"""Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored.
|
||||
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
|
||||
"""
|
||||
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
||||
|
||||
Arguments:
|
||||
attention_mask: torch.Tensor with 1 indicating tokens to ATTEND to
|
||||
input_shape: tuple, shape of input_ids
|
||||
device: torch.Device, usually self.device
|
||||
attention_mask (:obj:`torch.Tensor`):
|
||||
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
||||
input_shape (:obj:`Tuple[int]`):
|
||||
The shape of the input to the model.
|
||||
device (:obj:`torch.device`):
|
||||
The device of the input to the model.
|
||||
|
||||
Returns:
|
||||
torch.Tensor with dtype of attention_mask.dtype
|
||||
:obj:`torch.Tensor`: The extended attention mask, with the same dtype as :obj:`attention_mask`.
|
||||
"""
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
@@ -233,17 +278,23 @@ class ModuleUtilsMixin:
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
return extended_attention_mask
|
||||
|
||||
def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: bool = False) -> Tensor:
|
||||
def get_head_mask(
|
||||
self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
|
||||
) -> Tensor:
|
||||
"""
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
attention_probs has shape bsz x n_heads x N x N
|
||||
Arguments:
|
||||
head_mask: torch.Tensor or None: has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
num_hidden_layers: int
|
||||
Prepare the head mask if needed.
|
||||
|
||||
Args:
|
||||
head_mask (:obj:`torch.Tensor` with shape :obj:`[num_heads]` or :obj:`[num_hidden_layers x num_heads]`, `optional`):
|
||||
The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
|
||||
num_hidden_layers (:obj:`int`):
|
||||
The number of hidden layers in the model.
|
||||
is_attention_chunked (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not the attention scores are computed by chunks.
|
||||
|
||||
Returns:
|
||||
Tensor of shape shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
or list with [None] for each layer
|
||||
:obj:`torch.Tensor` with shape :obj:`[num_hidden_layers x batch x num_heads x seq_length x seq_length]`
|
||||
or list with :obj:`[None]` for each layer.
|
||||
"""
|
||||
if head_mask is not None:
|
||||
head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
|
||||
@@ -557,7 +608,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
``dbmdz/bert-base-german-cased``.
|
||||
- A path to a `directory` containing model weights saved using
|
||||
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a `tensorflow index checkpoint file` (e.g, `./tf_model/model.ckpt.index`). In
|
||||
- A path or url to a `tensorflow index checkpoint file` (e.g., ``./tf_model/model.ckpt.index``). In
|
||||
this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
|
||||
as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in
|
||||
a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||
@@ -610,7 +661,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
Whether or not to only look at local files (e.g., not try downloading the model).
|
||||
use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
|
||||
our S3 (faster).
|
||||
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
|
||||
kwargs (remaining dictionary of keyword arguments, `optional`):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
:obj:`output_attention=True`). Behaves differently depending on whether a ``config`` is provided or
|
||||
@@ -870,10 +921,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
|
||||
|
||||
class Conv1D(nn.Module):
|
||||
"""
|
||||
1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
|
||||
|
||||
Basically works like a linear layer but the weights are transposed.
|
||||
|
||||
Args:
|
||||
nf (:obj:`int`): The number of output features.
|
||||
nx (:obj:`int`): The number of input features.
|
||||
"""
|
||||
|
||||
def __init__(self, nf, nx):
|
||||
""" Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
|
||||
Basically works like a Linear layer but the weights are transposed
|
||||
"""
|
||||
super().__init__()
|
||||
self.nf = nf
|
||||
w = torch.empty(nx, nf)
|
||||
@@ -889,17 +947,31 @@ class Conv1D(nn.Module):
|
||||
|
||||
|
||||
class PoolerStartLogits(nn.Module):
|
||||
""" Compute SQuAD start_logits from sequence hidden states. """
|
||||
"""
|
||||
Compute SQuAD start logits from sequence hidden states.
|
||||
|
||||
def __init__(self, config):
|
||||
Args:
|
||||
config (:class:`~transformers.PretrainedConfig`):
|
||||
The config used by the model, will be used to grab the :obj:`hidden_size` of the model.
|
||||
"""
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, 1)
|
||||
|
||||
def forward(self, hidden_states, p_mask=None):
|
||||
""" Args:
|
||||
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)`
|
||||
invalid position mask such as query and special symbols (PAD, SEP, CLS)
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
|
||||
The final hidden states of the model.
|
||||
p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
|
||||
Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS).
|
||||
1.0 means token should be masked.
|
||||
|
||||
Returns:
|
||||
:obj:`torch.FloatTensor`: The start logits for SQuAD.
|
||||
"""
|
||||
x = self.dense(hidden_states).squeeze(-1)
|
||||
|
||||
@@ -913,28 +985,48 @@ class PoolerStartLogits(nn.Module):
|
||||
|
||||
|
||||
class PoolerEndLogits(nn.Module):
|
||||
""" Compute SQuAD end_logits from sequence hidden states and start token hidden state.
|
||||
"""
|
||||
Compute SQuAD end logits from sequence hidden states.
|
||||
|
||||
Args:
|
||||
config (:class:`~transformers.PretrainedConfig`):
|
||||
The config used by the model, will be used to grab the :obj:`hidden_size` of the model and the
|
||||
:obj:`layer_norm_eps` to use.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__()
|
||||
self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dense_1 = nn.Linear(config.hidden_size, 1)
|
||||
|
||||
def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
|
||||
""" Args:
|
||||
One of ``start_states``, ``start_positions`` should be not None.
|
||||
If both are set, ``start_positions`` overrides ``start_states``.
|
||||
|
||||
**start_states**: ``torch.LongTensor`` of shape identical to hidden_states
|
||||
hidden states of the first tokens for the labeled span.
|
||||
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||
position of the first token for the labeled span:
|
||||
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
|
||||
Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
start_states: Optional[torch.FloatTensor] = None,
|
||||
start_positions: Optional[torch.LongTensor] = None,
|
||||
p_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
|
||||
The final hidden states of the model.
|
||||
start_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`, `optional`):
|
||||
The hidden states of the first tokens for the labeled span.
|
||||
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
The position of the first token for the labeled span.
|
||||
p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
|
||||
Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS).
|
||||
1.0 means token should be masked.
|
||||
|
||||
.. note::
|
||||
|
||||
One of ``start_states`` or ``start_positions`` should not be :obj:`None`. If both are set,
|
||||
``start_positions`` overrides ``start_states``.
|
||||
|
||||
Returns:
|
||||
:obj:`torch.FloatTensor`: The end logits for SQuAD.
|
||||
"""
|
||||
assert (
|
||||
start_states is not None or start_positions is not None
|
||||
@@ -960,7 +1052,13 @@ class PoolerEndLogits(nn.Module):
|
||||
|
||||
|
||||
class PoolerAnswerClass(nn.Module):
|
||||
""" Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
|
||||
"""
|
||||
Compute SQuAD 2.0 answer class from classification and start tokens hidden states.
|
||||
|
||||
Args:
|
||||
config (:class:`~transformers.PretrainedConfig`):
|
||||
The config used by the model, will be used to grab the :obj:`hidden_size` of the model.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
@@ -968,23 +1066,33 @@ class PoolerAnswerClass(nn.Module):
|
||||
self.activation = nn.Tanh()
|
||||
self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
|
||||
|
||||
def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
start_states: Optional[torch.FloatTensor] = None,
|
||||
start_positions: Optional[torch.LongTensor] = None,
|
||||
cls_index: Optional[torch.LongTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Args:
|
||||
One of ``start_states``, ``start_positions`` should be not None.
|
||||
If both are set, ``start_positions`` overrides ``start_states``.
|
||||
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
|
||||
The final hidden states of the model.
|
||||
start_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`, `optional`):
|
||||
The hidden states of the first tokens for the labeled span.
|
||||
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
The position of the first token for the labeled span.
|
||||
cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token.
|
||||
|
||||
**start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``.
|
||||
hidden states of the first tokens for the labeled span.
|
||||
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||
position of the first token for the labeled span.
|
||||
**cls_index**: torch.LongTensor of shape ``(batch_size,)``
|
||||
position of the CLS token. If None, take the last token.
|
||||
.. note::
|
||||
|
||||
note(Original repo):
|
||||
no dependency on end_feature so that we can obtain one single `cls_logits`
|
||||
for each sample
|
||||
One of ``start_states`` or ``start_positions`` should not be :obj:`None`. If both are set,
|
||||
``start_positions`` overrides ``start_states``.
|
||||
|
||||
Returns:
|
||||
:obj:`torch.FloatTensor`: The SQuAD 2.0 answer class.
|
||||
"""
|
||||
# No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
|
||||
hsz = hidden_states.shape[-1]
|
||||
assert (
|
||||
start_states is not None or start_positions is not None
|
||||
@@ -1009,7 +1117,7 @@ class PoolerAnswerClass(nn.Module):
|
||||
@dataclass
|
||||
class SquadHeadOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of question answering models using a :obj:`SquadHead`.
|
||||
Base class for outputs of question answering models using a :class:`~transformers.modeling_utils.SQuADHead`.
|
||||
|
||||
Args:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned if both :obj:`start_positions` and :obj:`end_positions` are provided):
|
||||
@@ -1036,44 +1144,13 @@ class SquadHeadOutput(ModelOutput):
|
||||
|
||||
|
||||
class SQuADHead(nn.Module):
|
||||
r""" A SQuAD head inspired by XLNet.
|
||||
r"""
|
||||
A SQuAD head inspired by XLNet.
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
|
||||
|
||||
Inputs:
|
||||
**hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
|
||||
hidden states of sequence tokens
|
||||
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||
position of the first token for the labeled span.
|
||||
**end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||
position of the last token for the labeled span.
|
||||
**cls_index**: torch.LongTensor of shape ``(batch_size,)``
|
||||
position of the CLS token. If None, take the last token.
|
||||
**is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||
Whether the question has a possible answer in the paragraph or not.
|
||||
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
|
||||
Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
|
||||
1.0 means token should be masked.
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
|
||||
**start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
|
||||
Log probabilities for the top config.start_n_top start token possibilities (beam-search).
|
||||
**start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
|
||||
Indices for the top config.start_n_top start token possibilities (beam-search).
|
||||
**end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
|
||||
Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
||||
**end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
|
||||
Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
||||
**cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.FloatTensor`` of shape ``(batch_size,)``
|
||||
Log probabilities for the ``is_impossible`` label of the answers.
|
||||
Args:
|
||||
config (:class:`~transformers.PretrainedConfig`):
|
||||
The config used by the model, will be used to grab the :obj:`hidden_size` of the model and the
|
||||
:obj:`layer_norm_eps` to use.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
@@ -1085,16 +1162,37 @@ class SQuADHead(nn.Module):
|
||||
self.end_logits = PoolerEndLogits(config)
|
||||
self.answer_class = PoolerAnswerClass(config)
|
||||
|
||||
@replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig)
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
cls_index=None,
|
||||
is_impossible=None,
|
||||
p_mask=None,
|
||||
return_tuple=False,
|
||||
):
|
||||
hidden_states: torch.FloatTensor,
|
||||
start_positions: Optional[torch.LongTensor] = None,
|
||||
end_positions: Optional[torch.LongTensor] = None,
|
||||
cls_index: Optional[torch.LongTensor] = None,
|
||||
is_impossible: Optional[torch.LongTensor] = None,
|
||||
p_mask: Optional[torch.FloatTensor] = None,
|
||||
return_tuple: bool = False,
|
||||
) -> Union[SquadHeadOutput, Tuple[torch.FloatTensor]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
|
||||
Final hidden states of the model on the sequence tokens.
|
||||
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Positions of the first token for the labeled span.
|
||||
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Positions of the last token for the labeled span.
|
||||
cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token.
|
||||
is_impossible (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Whether the question has a possible answer in the paragraph or not.
|
||||
p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
|
||||
Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS).
|
||||
1.0 means token should be masked.
|
||||
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to return a plain tuple instead of a :class:`~transformers.file_utils.ModelOutput`.
|
||||
|
||||
Returns:
|
||||
"""
|
||||
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
|
||||
|
||||
if start_positions is not None and end_positions is not None:
|
||||
@@ -1163,19 +1261,31 @@ class SQuADHead(nn.Module):
|
||||
|
||||
|
||||
class SequenceSummary(nn.Module):
|
||||
r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
|
||||
Args of the config class:
|
||||
summary_type:
|
||||
- 'last' => [default] take the last token hidden state (like XLNet)
|
||||
- 'first' => take the first token hidden state (like Bert)
|
||||
- 'mean' => take the mean of all tokens hidden states
|
||||
- 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
|
||||
- 'attn' => Not implemented now, use multi-head attention
|
||||
summary_use_proj: Add a projection after the vector extraction
|
||||
summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
|
||||
summary_activation: 'tanh' or another string => add an activation to the output, Other => no activation. Default
|
||||
summary_first_dropout: Add a dropout before the projection and activation
|
||||
summary_last_dropout: Add a dropout after the projection and activation
|
||||
r"""
|
||||
Compute a single vector summary of a sequence hidden states.
|
||||
|
||||
Args:
|
||||
config (:class:`~transformers.PretrainedConfig`):
|
||||
The config used by the model. Relevant arguments in the config class of the model are (refer to the
|
||||
actual config class of your model for the default values it uses):
|
||||
|
||||
- **summary_type** (:obj:`str`) -- The method to use to make this summary. Accepted values are:
|
||||
|
||||
- :obj:`"last"` -- Take the last token hidden state (like XLNet)
|
||||
- :obj:`"first"` -- Take the first token hidden state (like Bert)
|
||||
- :obj:`"mean"` -- Take the mean of all tokens hidden states
|
||||
- :obj:`"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
|
||||
- :obj:`"attn"` -- Not implemented now, use multi-head attention
|
||||
|
||||
- **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction.
|
||||
- **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to
|
||||
:obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`).
|
||||
- **summary_activation** (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the
|
||||
output, another string or :obj:`None` will add no activation.
|
||||
- **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and
|
||||
activation.
|
||||
- **summary_last_dropout** (:obj:`float`) -- Optional dropout probability after the projection and
|
||||
activation.
|
||||
"""
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
@@ -1207,12 +1317,21 @@ class SequenceSummary(nn.Module):
|
||||
if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
|
||||
self.last_dropout = nn.Dropout(config.summary_last_dropout)
|
||||
|
||||
def forward(self, hidden_states, cls_index=None):
|
||||
""" hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
|
||||
cls_index: [optional] position of the classification token if summary_type == 'cls_index',
|
||||
shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
|
||||
if summary_type == 'cls_index' and cls_index is None:
|
||||
we take the last token of the sequence as classification token
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Compute a single vector summary of a sequence hidden states.
|
||||
|
||||
Args:
|
||||
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`[batch_size, seq_len, hidden_size]`):
|
||||
The hidden states of the last layer.
|
||||
cls_index (:obj:`torch.LongTensor` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ... are optional leading dimensions of :obj:`hidden_states`, `optional`):
|
||||
Used if :obj:`summary_type == "cls_index"`. If :obj:`None`, the last token of the sequence is taken as the
|
||||
classification token.
|
||||
|
||||
Returns:
|
||||
:obj:`torch.FloatTensor`: The summary of the sequence hidden states.
|
||||
"""
|
||||
if self.summary_type == "last":
|
||||
output = hidden_states[:, -1]
|
||||
@@ -1239,10 +1358,19 @@ class SequenceSummary(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
def prune_linear_layer(layer, index, dim=0):
|
||||
""" Prune a linear layer (a model parameters) to keep only entries in index.
|
||||
Return the pruned layer as a new layer with requires_grad=True.
|
||||
Used to remove heads.
|
||||
def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0) -> torch.nn.Linear:
|
||||
"""
|
||||
Prune a linear layer to keep only entries in index.
|
||||
|
||||
Used to remove heads.
|
||||
|
||||
Args:
|
||||
layer (:obj:`torch.nn.Linear`): The layer to prune.
|
||||
index (:obj:`torch.LongTensor`): The indices to keep in the layer.
|
||||
dim (:obj:`int`, `optional`, defaults to 0): The dimension on which to keep the indices.
|
||||
|
||||
Returns:
|
||||
:obj:`torch.nn.Linear`: The pruned layer as a new layer with :obj:`requires_grad=True`.
|
||||
"""
|
||||
index = index.to(layer.weight.device)
|
||||
W = layer.weight.index_select(dim, index).clone().detach()
|
||||
@@ -1264,11 +1392,20 @@ def prune_linear_layer(layer, index, dim=0):
|
||||
return new_layer
|
||||
|
||||
|
||||
def prune_conv1d_layer(layer, index, dim=1):
|
||||
""" Prune a Conv1D layer (a model parameters) to keep only entries in index.
|
||||
A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed.
|
||||
Return the pruned layer as a new layer with requires_grad=True.
|
||||
Used to remove heads.
|
||||
def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
|
||||
"""
|
||||
Prune a Conv1D layer to keep only entries in index. A Conv1D works as a Linear layer (see e.g. BERT) but the weights
|
||||
are transposed.
|
||||
|
||||
Used to remove heads.
|
||||
|
||||
Args:
|
||||
layer (:class:`~transformers.modeling_utils.Conv1D`): The layer to prune.
|
||||
index (:obj:`torch.LongTensor`): The indices to keep in the layer.
|
||||
dim (:obj:`int`, `optional`, defaults to 1): The dimension on which to keep the indices.
|
||||
|
||||
Returns:
|
||||
:class:`~transformers.modeling_utils.Conv1D`: The pruned layer as a new layer with :obj:`requires_grad=True`.
|
||||
"""
|
||||
index = index.to(layer.weight.device)
|
||||
W = layer.weight.index_select(dim, index).clone().detach()
|
||||
@@ -1288,10 +1425,22 @@ def prune_conv1d_layer(layer, index, dim=1):
|
||||
return new_layer
|
||||
|
||||
|
||||
def prune_layer(layer, index, dim=None):
|
||||
""" Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index.
|
||||
Return the pruned layer as a new layer with requires_grad=True.
|
||||
Used to remove heads.
|
||||
def prune_layer(
|
||||
layer: Union[torch.nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None
|
||||
) -> Union[torch.nn.Linear, Conv1D]:
|
||||
"""
|
||||
Prune a Conv1D or linear layer to keep only entries in index.
|
||||
|
||||
Used to remove heads.
|
||||
|
||||
Args:
|
||||
layer (:obj:`Union[torch.nn.Linear, Conv1D]`): The layer to prune.
|
||||
index (:obj:`torch.LongTensor`): The indices to keep in the layer.
|
||||
dim (:obj:`int`, `optional`): The dimension on which to keep the indices.
|
||||
|
||||
Returns:
|
||||
:obj:`torch.nn.Linear` or :class:`~transformers.modeling_utils.Conv1D`:
|
||||
The pruned layer as a new layer with :obj:`requires_grad=True`.
|
||||
"""
|
||||
if isinstance(layer, nn.Linear):
|
||||
return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
|
||||
|
||||
Reference in New Issue
Block a user