Add caching mechanism to BERT, RoBERTa (#9183)
* add past_key_values * add use_cache option * make mask before cutting ids * adjust position_ids according to past_key_values * flatten past_key_values * fix positional embeds * fix _reorder_cache * set use_cache to false when not decoder, fix attention mask init * add test for caching * add past_key_values for Roberta * fix position embeds * add caching test for roberta * add doc * make style * doc, fix attention mask, test * small fixes * adress patrick's comments * input_ids shouldn't start with pad token * use_cache only when decoder * make consistent with bert * make copies consistent * add use_cache to encoder * add past_key_values to tapas attention * apply suggestions from code review * make coppies consistent * add attn mask in tests * remove copied from longformer * apply suggestions from code review * fix bart test * nit * simplify model outputs * fix doc * fix output ordering
This commit is contained in:
@@ -126,13 +126,6 @@ CausalLMOutputWithCrossAttentions
|
|||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
CausalLMOutputWithPastAndCrossAttentions
|
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
|
|
||||||
.. autoclass:: transformers.modeling_outputs.CausalLMOutputWithPastAndCrossAttentions
|
|
||||||
:members:
|
|
||||||
|
|
||||||
|
|
||||||
CausalLMOutputWithPast
|
CausalLMOutputWithPast
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -175,11 +175,19 @@ class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
|
|||||||
|
|
||||||
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||||
weighted average in the cross-attention heads.
|
weighted average in the cross-attention heads.
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
|
Tuple of :obj:`torch.FloatTensor` tuples of length :obj:`config.n_layers`, with each tuple containing the
|
||||||
|
cached key, value states of the self-attention and the cross-attention layers if model is used in
|
||||||
|
encoder-decoder setting. Only relevant if ``config.is_decoder = True``.
|
||||||
|
|
||||||
|
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
||||||
|
:obj:`past_key_values` input) to speed up sequential decoding.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
last_hidden_state: torch.FloatTensor = None
|
last_hidden_state: torch.FloatTensor = None
|
||||||
pooler_output: torch.FloatTensor = None
|
pooler_output: torch.FloatTensor = None
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
@@ -379,53 +387,18 @@ class CausalLMOutputWithCrossAttentions(ModelOutput):
|
|||||||
|
|
||||||
Cross attentions weights after the attention softmax, used to compute the weighted average in the
|
Cross attentions weights after the attention softmax, used to compute the weighted average in the
|
||||||
cross-attention heads.
|
cross-attention heads.
|
||||||
"""
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
|
Tuple of :obj:`torch.FloatTensor` tuples of length :obj:`config.n_layers`, with each tuple containing the
|
||||||
loss: Optional[torch.FloatTensor] = None
|
cached key, value states of the self-attention and the cross-attention layers if model is used in
|
||||||
logits: torch.FloatTensor = None
|
encoder-decoder setting. Only relevant if ``config.is_decoder = True``.
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
|
||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
|
||||||
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CausalLMOutputWithPastAndCrossAttentions(ModelOutput):
|
|
||||||
"""
|
|
||||||
Base class for causal language model (or autoregressive) outputs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
|
||||||
Language modeling loss (for next-token prediction).
|
|
||||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
|
||||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
|
||||||
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
|
||||||
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
|
|
||||||
batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
||||||
:obj:`past_key_values` input) to speed up sequential decoding.
|
:obj:`past_key_values` input) to speed up sequential decoding.
|
||||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
|
||||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
|
||||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
|
||||||
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
|
||||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
|
||||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
|
||||||
sequence_length, sequence_length)`.
|
|
||||||
|
|
||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
|
||||||
heads.
|
|
||||||
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
|
||||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
|
||||||
sequence_length, sequence_length)`.
|
|
||||||
|
|
||||||
Cross attentions weights after the attention softmax, used to compute the weighted average in the
|
|
||||||
cross-attention heads.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
loss: Optional[torch.FloatTensor] = None
|
loss: Optional[torch.FloatTensor] = None
|
||||||
logits: torch.FloatTensor = None
|
logits: torch.FloatTensor = None
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|||||||
@@ -217,7 +217,9 @@ class AlbertEmbeddings(nn.Module):
|
|||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
|
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
|
||||||
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
def forward(
|
||||||
|
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
||||||
|
):
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
input_shape = input_ids.size()
|
input_shape = input_ids.size()
|
||||||
else:
|
else:
|
||||||
@@ -226,7 +228,7 @@ class AlbertEmbeddings(nn.Module):
|
|||||||
seq_length = input_shape[1]
|
seq_length = input_shape[1]
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = self.position_ids[:, :seq_length]
|
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
||||||
|
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
||||||
|
|||||||
@@ -98,6 +98,9 @@ class BertConfig(PretrainedConfig):
|
|||||||
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
|
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
|
||||||
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
|
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
|
||||||
<https://arxiv.org/abs/2009.13658>`__.
|
<https://arxiv.org/abs/2009.13658>`__.
|
||||||
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||||
|
relevant if ``config.is_decoder=True``.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
@@ -131,6 +134,7 @@ class BertConfig(PretrainedConfig):
|
|||||||
pad_token_id=0,
|
pad_token_id=0,
|
||||||
gradient_checkpointing=False,
|
gradient_checkpointing=False,
|
||||||
position_embedding_type="absolute",
|
position_embedding_type="absolute",
|
||||||
|
use_cache=True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
||||||
@@ -149,3 +153,4 @@ class BertConfig(PretrainedConfig):
|
|||||||
self.layer_norm_eps = layer_norm_eps
|
self.layer_norm_eps = layer_norm_eps
|
||||||
self.gradient_checkpointing = gradient_checkpointing
|
self.gradient_checkpointing = gradient_checkpointing
|
||||||
self.position_embedding_type = position_embedding_type
|
self.position_embedding_type = position_embedding_type
|
||||||
|
self.use_cache = use_cache
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ from ...file_utils import (
|
|||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||||
CausalLMOutputWithCrossAttentions,
|
CausalLMOutputWithCrossAttentions,
|
||||||
MaskedLMOutput,
|
MaskedLMOutput,
|
||||||
@@ -180,7 +180,9 @@ class BertEmbeddings(nn.Module):
|
|||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
|
|
||||||
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
def forward(
|
||||||
|
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
||||||
|
):
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
input_shape = input_ids.size()
|
input_shape = input_ids.size()
|
||||||
else:
|
else:
|
||||||
@@ -189,7 +191,7 @@ class BertEmbeddings(nn.Module):
|
|||||||
seq_length = input_shape[1]
|
seq_length = input_shape[1]
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = self.position_ids[:, :seq_length]
|
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
||||||
|
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
||||||
@@ -230,6 +232,8 @@ class BertSelfAttention(nn.Module):
|
|||||||
self.max_position_embeddings = config.max_position_embeddings
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||||
|
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(*new_x_shape)
|
||||||
@@ -242,6 +246,7 @@ class BertSelfAttention(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
@@ -249,17 +254,37 @@ class BertSelfAttention(nn.Module):
|
|||||||
# If this is instantiated as a cross-attention module, the keys
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
# and values come from an encoder; the attention mask needs to be
|
# and values come from an encoder; the attention mask needs to be
|
||||||
# such that the encoder's padding tokens are not attended to.
|
# such that the encoder's padding tokens are not attended to.
|
||||||
if encoder_hidden_states is not None:
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
mixed_key_layer = self.key(encoder_hidden_states)
|
|
||||||
mixed_value_layer = self.value(encoder_hidden_states)
|
if is_cross_attention and past_key_value is not None:
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_layer = past_key_value[0]
|
||||||
|
value_layer = past_key_value[1]
|
||||||
attention_mask = encoder_attention_mask
|
attention_mask = encoder_attention_mask
|
||||||
|
elif is_cross_attention:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif past_key_value is not None:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||||
|
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||||
else:
|
else:
|
||||||
mixed_key_layer = self.key(hidden_states)
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
mixed_value_layer = self.value(hidden_states)
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
|
||||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
|
||||||
value_layer = self.transpose_for_scores(mixed_value_layer)
|
if self.is_decoder:
|
||||||
|
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||||
|
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||||
|
# key/value_states (first "if" case)
|
||||||
|
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||||
|
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||||
|
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||||
|
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||||
|
past_key_value = (key_layer, value_layer)
|
||||||
|
|
||||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||||
@@ -303,6 +328,9 @@ class BertSelfAttention(nn.Module):
|
|||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(*new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (past_key_value,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -352,6 +380,7 @@ class BertAttention(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
self_outputs = self.self(
|
self_outputs = self.self(
|
||||||
@@ -360,6 +389,7 @@ class BertAttention(nn.Module):
|
|||||||
head_mask,
|
head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
attention_output = self.output(self_outputs[0], hidden_states)
|
attention_output = self.output(self_outputs[0], hidden_states)
|
||||||
@@ -417,36 +447,60 @@ class BertLayer(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
self_attention_outputs = self.attention(
|
self_attention_outputs = self.attention(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask,
|
head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
past_key_value=self_attn_past_key_value,
|
||||||
)
|
)
|
||||||
attention_output = self_attention_outputs[0]
|
attention_output = self_attention_outputs[0]
|
||||||
|
|
||||||
|
# if decoder, the last output is tuple of self-attn cache
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = self_attention_outputs[1:-1]
|
||||||
|
present_key_value = self_attention_outputs[-1]
|
||||||
|
else:
|
||||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
|
|
||||||
|
cross_attn_present_key_value = None
|
||||||
if self.is_decoder and encoder_hidden_states is not None:
|
if self.is_decoder and encoder_hidden_states is not None:
|
||||||
assert hasattr(
|
assert hasattr(
|
||||||
self, "crossattention"
|
self, "crossattention"
|
||||||
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
||||||
|
|
||||||
|
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||||
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
cross_attention_outputs = self.crossattention(
|
cross_attention_outputs = self.crossattention(
|
||||||
attention_output,
|
attention_output,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask,
|
head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
cross_attn_past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
attention_output = cross_attention_outputs[0]
|
attention_output = cross_attention_outputs[0]
|
||||||
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
|
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||||
|
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||||
|
present_key_value = present_key_value + cross_attn_present_key_value
|
||||||
|
|
||||||
layer_output = apply_chunking_to_forward(
|
layer_output = apply_chunking_to_forward(
|
||||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||||
)
|
)
|
||||||
outputs = (layer_output,) + outputs
|
outputs = (layer_output,) + outputs
|
||||||
|
|
||||||
|
# if decoder, return the attn key/values as the last output
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (present_key_value,)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def feed_forward_chunk(self, attention_output):
|
def feed_forward_chunk(self, attention_output):
|
||||||
@@ -468,6 +522,8 @@ class BertEncoder(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
@@ -475,17 +531,19 @@ class BertEncoder(nn.Module):
|
|||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False):
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
return module(*inputs, output_attentions)
|
return module(*inputs, past_key_value, output_attentions)
|
||||||
|
|
||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
@@ -504,9 +562,13 @@ class BertEncoder(nn.Module):
|
|||||||
layer_head_mask,
|
layer_head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[-1],)
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||||
if self.config.add_cross_attention:
|
if self.config.add_cross_attention:
|
||||||
@@ -518,11 +580,18 @@ class BertEncoder(nn.Module):
|
|||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(
|
return tuple(
|
||||||
v
|
v
|
||||||
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
|
for v in [
|
||||||
|
hidden_states,
|
||||||
|
next_decoder_cache,
|
||||||
|
all_hidden_states,
|
||||||
|
all_self_attentions,
|
||||||
|
all_cross_attentions,
|
||||||
|
]
|
||||||
if v is not None
|
if v is not None
|
||||||
)
|
)
|
||||||
return BaseModelOutputWithCrossAttentions(
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_decoder_cache,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attentions,
|
attentions=all_self_attentions,
|
||||||
cross_attentions=all_cross_attentions,
|
cross_attentions=all_cross_attentions,
|
||||||
@@ -799,6 +868,8 @@ class BertModel(BertPreTrainedModel):
|
|||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
@@ -813,6 +884,15 @@ class BertModel(BertPreTrainedModel):
|
|||||||
|
|
||||||
- 1 for tokens that are **not masked**,
|
- 1 for tokens that are **not masked**,
|
||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
"""
|
"""
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@@ -820,19 +900,29 @@ class BertModel(BertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if self.config.is_decoder:
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
else:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
elif input_ids is not None:
|
elif input_ids is not None:
|
||||||
input_shape = input_ids.size()
|
input_shape = input_ids.size()
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
elif inputs_embeds is not None:
|
elif inputs_embeds is not None:
|
||||||
input_shape = inputs_embeds.size()[:-1]
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
|
||||||
|
# past_key_values_length
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones(input_shape, device=device)
|
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||||
|
|
||||||
@@ -859,7 +949,11 @@ class BertModel(BertPreTrainedModel):
|
|||||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
|
||||||
embedding_output = self.embeddings(
|
embedding_output = self.embeddings(
|
||||||
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
input_ids=input_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
)
|
)
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
embedding_output,
|
embedding_output,
|
||||||
@@ -867,6 +961,8 @@ class BertModel(BertPreTrainedModel):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_extended_attention_mask,
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -880,6 +976,7 @@ class BertModel(BertPreTrainedModel):
|
|||||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||||
last_hidden_state=sequence_output,
|
last_hidden_state=sequence_output,
|
||||||
pooler_output=pooled_output,
|
pooler_output=pooled_output,
|
||||||
|
past_key_values=encoder_outputs.past_key_values,
|
||||||
hidden_states=encoder_outputs.hidden_states,
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
attentions=encoder_outputs.attentions,
|
attentions=encoder_outputs.attentions,
|
||||||
cross_attentions=encoder_outputs.cross_attentions,
|
cross_attentions=encoder_outputs.cross_attentions,
|
||||||
@@ -1029,6 +1126,8 @@ class BertLMHeadModel(BertPreTrainedModel):
|
|||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
@@ -1047,6 +1146,15 @@ class BertLMHeadModel(BertPreTrainedModel):
|
|||||||
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
||||||
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
||||||
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
|
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
@@ -1066,6 +1174,8 @@ class BertLMHeadModel(BertPreTrainedModel):
|
|||||||
>>> prediction_logits = outputs.logits
|
>>> prediction_logits = outputs.logits
|
||||||
"""
|
"""
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
if labels is not None:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
outputs = self.bert(
|
outputs = self.bert(
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -1076,6 +1186,8 @@ class BertLMHeadModel(BertPreTrainedModel):
|
|||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -1099,20 +1211,30 @@ class BertLMHeadModel(BertPreTrainedModel):
|
|||||||
return CausalLMOutputWithCrossAttentions(
|
return CausalLMOutputWithCrossAttentions(
|
||||||
loss=lm_loss,
|
loss=lm_loss,
|
||||||
logits=prediction_scores,
|
logits=prediction_scores,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
cross_attentions=outputs.cross_attentions,
|
cross_attentions=outputs.cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
||||||
input_shape = input_ids.shape
|
input_shape = input_ids.shape
|
||||||
|
|
||||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = input_ids.new_ones(input_shape)
|
attention_mask = input_ids.new_ones(input_shape)
|
||||||
|
|
||||||
|
# cut decoder_input_ids if past is used
|
||||||
|
if past is not None:
|
||||||
|
input_ids = input_ids[:, -1:]
|
||||||
|
|
||||||
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||||
|
|
||||||
|
def _reorder_cache(self, past, beam_idx):
|
||||||
|
reordered_past = ()
|
||||||
|
for layer_past in past:
|
||||||
|
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||||
|
return reordered_past
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
||||||
class BertForMaskedLM(BertPreTrainedModel):
|
class BertForMaskedLM(BertPreTrainedModel):
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from ...file_utils import (
|
|||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from ...modeling_outputs import BaseModelOutputWithCrossAttentions, CausalLMOutputWithCrossAttentions
|
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..bert.modeling_bert import BertEncoder
|
from ..bert.modeling_bert import BertEncoder
|
||||||
@@ -144,7 +144,7 @@ class BertGenerationEmbeddings(nn.Module):
|
|||||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
|
|
||||||
def forward(self, input_ids=None, position_ids=None, inputs_embeds=None):
|
def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0):
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
input_shape = input_ids.size()
|
input_shape = input_ids.size()
|
||||||
else:
|
else:
|
||||||
@@ -153,7 +153,7 @@ class BertGenerationEmbeddings(nn.Module):
|
|||||||
seq_length = input_shape[1]
|
seq_length = input_shape[1]
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = self.position_ids[:, :seq_length]
|
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.word_embeddings(input_ids)
|
inputs_embeds = self.word_embeddings(input_ids)
|
||||||
@@ -297,7 +297,7 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
|
|||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint="google/bert_for_seq_generation_L-24_bbc_encoder",
|
checkpoint="google/bert_for_seq_generation_L-24_bbc_encoder",
|
||||||
output_type=BaseModelOutputWithCrossAttentions,
|
output_type=BaseModelOutputWithPastAndCrossAttentions,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
@@ -309,6 +309,8 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
|
|||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
@@ -321,6 +323,15 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
|
|||||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: ``1`` for
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: ``1`` for
|
||||||
tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
"""
|
"""
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@@ -332,19 +343,28 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
|
|||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
elif input_ids is not None:
|
elif input_ids is not None:
|
||||||
input_shape = input_ids.size()
|
input_shape = input_ids.size()
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
elif inputs_embeds is not None:
|
elif inputs_embeds is not None:
|
||||||
input_shape = inputs_embeds.size()[:-1]
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
|
||||||
|
# past_key_values_length
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones(input_shape, device=device)
|
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||||
|
|
||||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
extended_attention_mask = None
|
||||||
|
if not use_cache:
|
||||||
|
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
|
||||||
|
attention_mask, input_shape, device
|
||||||
|
)
|
||||||
|
|
||||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
@@ -364,7 +384,12 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
|
|||||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
|
||||||
embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds)
|
embedding_output = self.embeddings(
|
||||||
|
input_ids=input_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
|
)
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
embedding_output,
|
embedding_output,
|
||||||
@@ -372,6 +397,8 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_extended_attention_mask,
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -381,8 +408,9 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
|
|||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (sequence_output,) + encoder_outputs[1:]
|
return (sequence_output,) + encoder_outputs[1:]
|
||||||
|
|
||||||
return BaseModelOutputWithCrossAttentions(
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=sequence_output,
|
last_hidden_state=sequence_output,
|
||||||
|
past_key_values=encoder_outputs.past_key_values,
|
||||||
hidden_states=encoder_outputs.hidden_states,
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
attentions=encoder_outputs.attentions,
|
attentions=encoder_outputs.attentions,
|
||||||
cross_attentions=encoder_outputs.cross_attentions,
|
cross_attentions=encoder_outputs.cross_attentions,
|
||||||
@@ -437,6 +465,8 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
|
|||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
@@ -455,6 +485,15 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
|
|||||||
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
||||||
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
||||||
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
@@ -474,6 +513,8 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
|
|||||||
>>> prediction_logits = outputs.logits
|
>>> prediction_logits = outputs.logits
|
||||||
"""
|
"""
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
if labels is not None:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
outputs = self.bert(
|
outputs = self.bert(
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -483,6 +524,8 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
|
|||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -506,16 +549,26 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
|
|||||||
return CausalLMOutputWithCrossAttentions(
|
return CausalLMOutputWithCrossAttentions(
|
||||||
loss=lm_loss,
|
loss=lm_loss,
|
||||||
logits=prediction_scores,
|
logits=prediction_scores,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
cross_attentions=outputs.cross_attentions,
|
cross_attentions=outputs.cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
||||||
input_shape = input_ids.shape
|
input_shape = input_ids.shape
|
||||||
|
|
||||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = input_ids.new_ones(input_shape)
|
attention_mask = input_ids.new_ones(input_shape)
|
||||||
|
|
||||||
|
# cut decoder_input_ids if past is used
|
||||||
|
if past is not None:
|
||||||
|
input_ids = input_ids[:, -1:]
|
||||||
|
|
||||||
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||||
|
|
||||||
|
def _reorder_cache(self, past, beam_idx):
|
||||||
|
reordered_past = ()
|
||||||
|
for layer_past in past:
|
||||||
|
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||||
|
return reordered_past
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from ...file_utils import (
|
|||||||
)
|
)
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithCrossAttentions,
|
BaseModelOutputWithCrossAttentions,
|
||||||
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
MaskedLMOutput,
|
MaskedLMOutput,
|
||||||
MultipleChoiceModelOutput,
|
MultipleChoiceModelOutput,
|
||||||
QuestionAnsweringModelOutput,
|
QuestionAnsweringModelOutput,
|
||||||
@@ -168,7 +169,9 @@ class ElectraEmbeddings(nn.Module):
|
|||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
|
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
|
||||||
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
def forward(
|
||||||
|
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
||||||
|
):
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
input_shape = input_ids.size()
|
input_shape = input_ids.size()
|
||||||
else:
|
else:
|
||||||
@@ -177,7 +180,7 @@ class ElectraEmbeddings(nn.Module):
|
|||||||
seq_length = input_shape[1]
|
seq_length = input_shape[1]
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = self.position_ids[:, :seq_length]
|
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
||||||
|
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
||||||
@@ -219,6 +222,8 @@ class ElectraSelfAttention(nn.Module):
|
|||||||
self.max_position_embeddings = config.max_position_embeddings
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||||
|
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(*new_x_shape)
|
||||||
@@ -231,6 +236,7 @@ class ElectraSelfAttention(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
@@ -238,17 +244,37 @@ class ElectraSelfAttention(nn.Module):
|
|||||||
# If this is instantiated as a cross-attention module, the keys
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
# and values come from an encoder; the attention mask needs to be
|
# and values come from an encoder; the attention mask needs to be
|
||||||
# such that the encoder's padding tokens are not attended to.
|
# such that the encoder's padding tokens are not attended to.
|
||||||
if encoder_hidden_states is not None:
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
mixed_key_layer = self.key(encoder_hidden_states)
|
|
||||||
mixed_value_layer = self.value(encoder_hidden_states)
|
if is_cross_attention and past_key_value is not None:
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_layer = past_key_value[0]
|
||||||
|
value_layer = past_key_value[1]
|
||||||
attention_mask = encoder_attention_mask
|
attention_mask = encoder_attention_mask
|
||||||
|
elif is_cross_attention:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif past_key_value is not None:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||||
|
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||||
else:
|
else:
|
||||||
mixed_key_layer = self.key(hidden_states)
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
mixed_value_layer = self.value(hidden_states)
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
|
||||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
|
||||||
value_layer = self.transpose_for_scores(mixed_value_layer)
|
if self.is_decoder:
|
||||||
|
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||||
|
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||||
|
# key/value_states (first "if" case)
|
||||||
|
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||||
|
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||||
|
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||||
|
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||||
|
past_key_value = (key_layer, value_layer)
|
||||||
|
|
||||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||||
@@ -292,6 +318,9 @@ class ElectraSelfAttention(nn.Module):
|
|||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(*new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (past_key_value,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -343,6 +372,7 @@ class ElectraAttention(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
self_outputs = self.self(
|
self_outputs = self.self(
|
||||||
@@ -351,6 +381,7 @@ class ElectraAttention(nn.Module):
|
|||||||
head_mask,
|
head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
attention_output = self.output(self_outputs[0], hidden_states)
|
attention_output = self.output(self_outputs[0], hidden_states)
|
||||||
@@ -411,36 +442,60 @@ class ElectraLayer(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
self_attention_outputs = self.attention(
|
self_attention_outputs = self.attention(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask,
|
head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
past_key_value=self_attn_past_key_value,
|
||||||
)
|
)
|
||||||
attention_output = self_attention_outputs[0]
|
attention_output = self_attention_outputs[0]
|
||||||
|
|
||||||
|
# if decoder, the last output is tuple of self-attn cache
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = self_attention_outputs[1:-1]
|
||||||
|
present_key_value = self_attention_outputs[-1]
|
||||||
|
else:
|
||||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
|
|
||||||
|
cross_attn_present_key_value = None
|
||||||
if self.is_decoder and encoder_hidden_states is not None:
|
if self.is_decoder and encoder_hidden_states is not None:
|
||||||
assert hasattr(
|
assert hasattr(
|
||||||
self, "crossattention"
|
self, "crossattention"
|
||||||
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
||||||
|
|
||||||
|
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||||
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
cross_attention_outputs = self.crossattention(
|
cross_attention_outputs = self.crossattention(
|
||||||
attention_output,
|
attention_output,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask,
|
head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
cross_attn_past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
attention_output = cross_attention_outputs[0]
|
attention_output = cross_attention_outputs[0]
|
||||||
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
|
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||||
|
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||||
|
present_key_value = present_key_value + cross_attn_present_key_value
|
||||||
|
|
||||||
layer_output = apply_chunking_to_forward(
|
layer_output = apply_chunking_to_forward(
|
||||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||||
)
|
)
|
||||||
outputs = (layer_output,) + outputs
|
outputs = (layer_output,) + outputs
|
||||||
|
|
||||||
|
# if decoder, return the attn key/values as the last output
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (present_key_value,)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def feed_forward_chunk(self, attention_output):
|
def feed_forward_chunk(self, attention_output):
|
||||||
@@ -463,6 +518,8 @@ class ElectraEncoder(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
@@ -470,17 +527,19 @@ class ElectraEncoder(nn.Module):
|
|||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False):
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
return module(*inputs, output_attentions)
|
return module(*inputs, past_key_value, output_attentions)
|
||||||
|
|
||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
@@ -499,9 +558,13 @@ class ElectraEncoder(nn.Module):
|
|||||||
layer_head_mask,
|
layer_head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[-1],)
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||||
if self.config.add_cross_attention:
|
if self.config.add_cross_attention:
|
||||||
@@ -513,11 +576,18 @@ class ElectraEncoder(nn.Module):
|
|||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(
|
return tuple(
|
||||||
v
|
v
|
||||||
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
|
for v in [
|
||||||
|
hidden_states,
|
||||||
|
next_decoder_cache,
|
||||||
|
all_hidden_states,
|
||||||
|
all_self_attentions,
|
||||||
|
all_cross_attentions,
|
||||||
|
]
|
||||||
if v is not None
|
if v is not None
|
||||||
)
|
)
|
||||||
return BaseModelOutputWithCrossAttentions(
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_decoder_cache,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attentions,
|
attentions=all_self_attentions,
|
||||||
cross_attentions=all_cross_attentions,
|
cross_attentions=all_cross_attentions,
|
||||||
|
|||||||
@@ -345,11 +345,11 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
past_key_values=None, # TODO: (PVP) implement :obj:`use_cache`
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
use_cache=None, # TODO: (PVP) implement :obj:`use_cache`
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
@@ -413,18 +413,19 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
labels=labels,
|
labels=labels,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
use_cache=use_cache,
|
||||||
|
past_key_values=past_key_values,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
**kwargs_decoder,
|
**kwargs_decoder,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(PVP): currently it is not possible to use `past`
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return decoder_outputs + encoder_outputs
|
return decoder_outputs + encoder_outputs
|
||||||
|
|
||||||
return Seq2SeqLMOutput(
|
return Seq2SeqLMOutput(
|
||||||
loss=decoder_outputs.loss,
|
loss=decoder_outputs.loss,
|
||||||
logits=decoder_outputs.logits,
|
logits=decoder_outputs.logits,
|
||||||
past_key_values=None, # TODO(PVP) - need to implement cache for BERT, etc... before this works
|
past_key_values=decoder_outputs.past_key_values,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs.attentions,
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
cross_attentions=decoder_outputs.cross_attentions,
|
cross_attentions=decoder_outputs.cross_attentions,
|
||||||
@@ -433,24 +434,19 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
encoder_attentions=encoder_outputs.attentions,
|
encoder_attentions=encoder_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, encoder_outputs=None, **kwargs):
|
def prepare_inputs_for_generation(
|
||||||
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
|
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
||||||
|
):
|
||||||
|
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
|
||||||
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
||||||
input_dict = {
|
input_dict = {
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"decoder_attention_mask": decoder_attention_mask,
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
"decoder_input_ids": decoder_inputs["input_ids"],
|
"decoder_input_ids": decoder_inputs["input_ids"],
|
||||||
"encoder_outputs": encoder_outputs,
|
"encoder_outputs": encoder_outputs,
|
||||||
|
"past_key_values": past,
|
||||||
|
"use_cache": use_cache,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Ideally all models should have a :obj:`use_cache`
|
|
||||||
# leave following to ifs until all have it implemented
|
|
||||||
if "use_cache" in decoder_inputs:
|
|
||||||
input_dict["decoder_use_cache"] = decoder_inputs["use_cache"]
|
|
||||||
|
|
||||||
if "past_key_values" in decoder_inputs:
|
|
||||||
input_dict["past_key_values"] = decoder_inputs["past_key_values"]
|
|
||||||
|
|
||||||
return input_dict
|
return input_dict
|
||||||
|
|
||||||
def _reorder_cache(self, past, beam_idx):
|
def _reorder_cache(self, past, beam_idx):
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from ...file_utils import (
|
|||||||
)
|
)
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
CausalLMOutputWithPastAndCrossAttentions,
|
CausalLMOutputWithCrossAttentions,
|
||||||
SequenceClassifierOutputWithPast,
|
SequenceClassifierOutputWithPast,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import (
|
from ...modeling_utils import (
|
||||||
@@ -851,7 +851,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint="gpt2",
|
checkpoint="gpt2",
|
||||||
output_type=CausalLMOutputWithPastAndCrossAttentions,
|
output_type=CausalLMOutputWithCrossAttentions,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
@@ -916,7 +916,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
output = (lm_logits,) + transformer_outputs[1:]
|
output = (lm_logits,) + transformer_outputs[1:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
return CausalLMOutputWithPastAndCrossAttentions(
|
return CausalLMOutputWithCrossAttentions(
|
||||||
loss=loss,
|
loss=loss,
|
||||||
logits=lm_logits,
|
logits=lm_logits,
|
||||||
past_key_values=transformer_outputs.past_key_values,
|
past_key_values=transformer_outputs.past_key_values,
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from torch.nn import CrossEntropyLoss
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||||
MaskedLMOutput,
|
MaskedLMOutput,
|
||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
@@ -151,6 +151,8 @@ class LayoutLMSelfAttention(nn.Module):
|
|||||||
self.max_position_embeddings = config.max_position_embeddings
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||||
|
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(*new_x_shape)
|
||||||
@@ -163,6 +165,7 @@ class LayoutLMSelfAttention(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
@@ -170,17 +173,37 @@ class LayoutLMSelfAttention(nn.Module):
|
|||||||
# If this is instantiated as a cross-attention module, the keys
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
# and values come from an encoder; the attention mask needs to be
|
# and values come from an encoder; the attention mask needs to be
|
||||||
# such that the encoder's padding tokens are not attended to.
|
# such that the encoder's padding tokens are not attended to.
|
||||||
if encoder_hidden_states is not None:
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
mixed_key_layer = self.key(encoder_hidden_states)
|
|
||||||
mixed_value_layer = self.value(encoder_hidden_states)
|
if is_cross_attention and past_key_value is not None:
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_layer = past_key_value[0]
|
||||||
|
value_layer = past_key_value[1]
|
||||||
attention_mask = encoder_attention_mask
|
attention_mask = encoder_attention_mask
|
||||||
|
elif is_cross_attention:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif past_key_value is not None:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||||
|
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||||
else:
|
else:
|
||||||
mixed_key_layer = self.key(hidden_states)
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
mixed_value_layer = self.value(hidden_states)
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
|
||||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
|
||||||
value_layer = self.transpose_for_scores(mixed_value_layer)
|
if self.is_decoder:
|
||||||
|
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||||
|
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||||
|
# key/value_states (first "if" case)
|
||||||
|
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||||
|
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||||
|
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||||
|
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||||
|
past_key_value = (key_layer, value_layer)
|
||||||
|
|
||||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||||
@@ -224,6 +247,9 @@ class LayoutLMSelfAttention(nn.Module):
|
|||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(*new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (past_key_value,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -275,6 +301,7 @@ class LayoutLMAttention(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
self_outputs = self.self(
|
self_outputs = self.self(
|
||||||
@@ -283,6 +310,7 @@ class LayoutLMAttention(nn.Module):
|
|||||||
head_mask,
|
head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
attention_output = self.output(self_outputs[0], hidden_states)
|
attention_output = self.output(self_outputs[0], hidden_states)
|
||||||
@@ -343,36 +371,60 @@ class LayoutLMLayer(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
self_attention_outputs = self.attention(
|
self_attention_outputs = self.attention(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask,
|
head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
past_key_value=self_attn_past_key_value,
|
||||||
)
|
)
|
||||||
attention_output = self_attention_outputs[0]
|
attention_output = self_attention_outputs[0]
|
||||||
|
|
||||||
|
# if decoder, the last output is tuple of self-attn cache
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = self_attention_outputs[1:-1]
|
||||||
|
present_key_value = self_attention_outputs[-1]
|
||||||
|
else:
|
||||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
|
|
||||||
|
cross_attn_present_key_value = None
|
||||||
if self.is_decoder and encoder_hidden_states is not None:
|
if self.is_decoder and encoder_hidden_states is not None:
|
||||||
assert hasattr(
|
assert hasattr(
|
||||||
self, "crossattention"
|
self, "crossattention"
|
||||||
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
||||||
|
|
||||||
|
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||||
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
cross_attention_outputs = self.crossattention(
|
cross_attention_outputs = self.crossattention(
|
||||||
attention_output,
|
attention_output,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask,
|
head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
cross_attn_past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
attention_output = cross_attention_outputs[0]
|
attention_output = cross_attention_outputs[0]
|
||||||
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
|
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||||
|
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||||
|
present_key_value = present_key_value + cross_attn_present_key_value
|
||||||
|
|
||||||
layer_output = apply_chunking_to_forward(
|
layer_output = apply_chunking_to_forward(
|
||||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||||
)
|
)
|
||||||
outputs = (layer_output,) + outputs
|
outputs = (layer_output,) + outputs
|
||||||
|
|
||||||
|
# if decoder, return the attn key/values as the last output
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (present_key_value,)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def feed_forward_chunk(self, attention_output):
|
def feed_forward_chunk(self, attention_output):
|
||||||
@@ -395,6 +447,8 @@ class LayoutLMEncoder(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
@@ -402,17 +456,19 @@ class LayoutLMEncoder(nn.Module):
|
|||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False):
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
return module(*inputs, output_attentions)
|
return module(*inputs, past_key_value, output_attentions)
|
||||||
|
|
||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
@@ -431,9 +487,13 @@ class LayoutLMEncoder(nn.Module):
|
|||||||
layer_head_mask,
|
layer_head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[-1],)
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||||
if self.config.add_cross_attention:
|
if self.config.add_cross_attention:
|
||||||
@@ -445,11 +505,18 @@ class LayoutLMEncoder(nn.Module):
|
|||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(
|
return tuple(
|
||||||
v
|
v
|
||||||
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
|
for v in [
|
||||||
|
hidden_states,
|
||||||
|
next_decoder_cache,
|
||||||
|
all_hidden_states,
|
||||||
|
all_self_attentions,
|
||||||
|
all_cross_attentions,
|
||||||
|
]
|
||||||
if v is not None
|
if v is not None
|
||||||
)
|
)
|
||||||
return BaseModelOutputWithCrossAttentions(
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_decoder_cache,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attentions,
|
attentions=all_self_attentions,
|
||||||
cross_attentions=all_cross_attentions,
|
cross_attentions=all_cross_attentions,
|
||||||
|
|||||||
@@ -424,7 +424,6 @@ def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=Tru
|
|||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
|
|
||||||
def create_position_ids_from_input_ids(input_ids, padding_idx):
|
def create_position_ids_from_input_ids(input_ids, padding_idx):
|
||||||
"""
|
"""
|
||||||
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from ...file_utils import (
|
|||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||||
CausalLMOutputWithCrossAttentions,
|
CausalLMOutputWithCrossAttentions,
|
||||||
MaskedLMOutput,
|
MaskedLMOutput,
|
||||||
@@ -91,25 +91,23 @@ class RobertaEmbeddings(nn.Module):
|
|||||||
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
|
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
def forward(
|
||||||
|
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
||||||
|
):
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
||||||
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx).to(input_ids.device)
|
position_ids = create_position_ids_from_input_ids(
|
||||||
|
input_ids, self.padding_idx, past_key_values_length
|
||||||
|
).to(input_ids.device)
|
||||||
else:
|
else:
|
||||||
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
|
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
|
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
input_shape = input_ids.size()
|
input_shape = input_ids.size()
|
||||||
else:
|
else:
|
||||||
input_shape = inputs_embeds.size()[:-1]
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
|
||||||
seq_length = input_shape[1]
|
|
||||||
|
|
||||||
if position_ids is None:
|
|
||||||
position_ids = self.position_ids[:, :seq_length]
|
|
||||||
|
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
||||||
|
|
||||||
@@ -167,6 +165,8 @@ class RobertaSelfAttention(nn.Module):
|
|||||||
self.max_position_embeddings = config.max_position_embeddings
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||||
|
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(*new_x_shape)
|
||||||
@@ -179,6 +179,7 @@ class RobertaSelfAttention(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
@@ -186,17 +187,37 @@ class RobertaSelfAttention(nn.Module):
|
|||||||
# If this is instantiated as a cross-attention module, the keys
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
# and values come from an encoder; the attention mask needs to be
|
# and values come from an encoder; the attention mask needs to be
|
||||||
# such that the encoder's padding tokens are not attended to.
|
# such that the encoder's padding tokens are not attended to.
|
||||||
if encoder_hidden_states is not None:
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
mixed_key_layer = self.key(encoder_hidden_states)
|
|
||||||
mixed_value_layer = self.value(encoder_hidden_states)
|
if is_cross_attention and past_key_value is not None:
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_layer = past_key_value[0]
|
||||||
|
value_layer = past_key_value[1]
|
||||||
attention_mask = encoder_attention_mask
|
attention_mask = encoder_attention_mask
|
||||||
|
elif is_cross_attention:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif past_key_value is not None:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||||
|
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||||
else:
|
else:
|
||||||
mixed_key_layer = self.key(hidden_states)
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
mixed_value_layer = self.value(hidden_states)
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
|
||||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
|
||||||
value_layer = self.transpose_for_scores(mixed_value_layer)
|
if self.is_decoder:
|
||||||
|
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||||
|
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||||
|
# key/value_states (first "if" case)
|
||||||
|
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||||
|
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||||
|
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||||
|
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||||
|
past_key_value = (key_layer, value_layer)
|
||||||
|
|
||||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||||
@@ -240,6 +261,9 @@ class RobertaSelfAttention(nn.Module):
|
|||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(*new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (past_key_value,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -291,6 +315,7 @@ class RobertaAttention(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
self_outputs = self.self(
|
self_outputs = self.self(
|
||||||
@@ -299,6 +324,7 @@ class RobertaAttention(nn.Module):
|
|||||||
head_mask,
|
head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
attention_output = self.output(self_outputs[0], hidden_states)
|
attention_output = self.output(self_outputs[0], hidden_states)
|
||||||
@@ -359,36 +385,60 @@ class RobertaLayer(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
self_attention_outputs = self.attention(
|
self_attention_outputs = self.attention(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask,
|
head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
past_key_value=self_attn_past_key_value,
|
||||||
)
|
)
|
||||||
attention_output = self_attention_outputs[0]
|
attention_output = self_attention_outputs[0]
|
||||||
|
|
||||||
|
# if decoder, the last output is tuple of self-attn cache
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = self_attention_outputs[1:-1]
|
||||||
|
present_key_value = self_attention_outputs[-1]
|
||||||
|
else:
|
||||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
|
|
||||||
|
cross_attn_present_key_value = None
|
||||||
if self.is_decoder and encoder_hidden_states is not None:
|
if self.is_decoder and encoder_hidden_states is not None:
|
||||||
assert hasattr(
|
assert hasattr(
|
||||||
self, "crossattention"
|
self, "crossattention"
|
||||||
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
||||||
|
|
||||||
|
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||||
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
cross_attention_outputs = self.crossattention(
|
cross_attention_outputs = self.crossattention(
|
||||||
attention_output,
|
attention_output,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask,
|
head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
cross_attn_past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
attention_output = cross_attention_outputs[0]
|
attention_output = cross_attention_outputs[0]
|
||||||
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
|
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||||
|
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||||
|
present_key_value = present_key_value + cross_attn_present_key_value
|
||||||
|
|
||||||
layer_output = apply_chunking_to_forward(
|
layer_output = apply_chunking_to_forward(
|
||||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||||
)
|
)
|
||||||
outputs = (layer_output,) + outputs
|
outputs = (layer_output,) + outputs
|
||||||
|
|
||||||
|
# if decoder, return the attn key/values as the last output
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (present_key_value,)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def feed_forward_chunk(self, attention_output):
|
def feed_forward_chunk(self, attention_output):
|
||||||
@@ -411,6 +461,8 @@ class RobertaEncoder(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
@@ -418,17 +470,19 @@ class RobertaEncoder(nn.Module):
|
|||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False):
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
return module(*inputs, output_attentions)
|
return module(*inputs, past_key_value, output_attentions)
|
||||||
|
|
||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
@@ -447,9 +501,13 @@ class RobertaEncoder(nn.Module):
|
|||||||
layer_head_mask,
|
layer_head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[-1],)
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||||
if self.config.add_cross_attention:
|
if self.config.add_cross_attention:
|
||||||
@@ -461,11 +519,18 @@ class RobertaEncoder(nn.Module):
|
|||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(
|
return tuple(
|
||||||
v
|
v
|
||||||
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
|
for v in [
|
||||||
|
hidden_states,
|
||||||
|
next_decoder_cache,
|
||||||
|
all_hidden_states,
|
||||||
|
all_self_attentions,
|
||||||
|
all_cross_attentions,
|
||||||
|
]
|
||||||
if v is not None
|
if v is not None
|
||||||
)
|
)
|
||||||
return BaseModelOutputWithCrossAttentions(
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_decoder_cache,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attentions,
|
attentions=all_self_attentions,
|
||||||
cross_attentions=all_cross_attentions,
|
cross_attentions=all_cross_attentions,
|
||||||
@@ -646,6 +711,8 @@ class RobertaModel(RobertaPreTrainedModel):
|
|||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
@@ -658,26 +725,44 @@ class RobertaModel(RobertaPreTrainedModel):
|
|||||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: ``1`` for
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: ``1`` for
|
||||||
tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
"""
|
"""
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
)
|
)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
if not self.config.is_decoder:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
elif input_ids is not None:
|
elif input_ids is not None:
|
||||||
input_shape = input_ids.size()
|
input_shape = input_ids.size()
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
elif inputs_embeds is not None:
|
elif inputs_embeds is not None:
|
||||||
input_shape = inputs_embeds.size()[:-1]
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
|
||||||
|
# past_key_values_length
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones(input_shape, device=device)
|
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||||
|
|
||||||
@@ -704,7 +789,11 @@ class RobertaModel(RobertaPreTrainedModel):
|
|||||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
|
||||||
embedding_output = self.embeddings(
|
embedding_output = self.embeddings(
|
||||||
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
input_ids=input_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
)
|
)
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
embedding_output,
|
embedding_output,
|
||||||
@@ -712,6 +801,8 @@ class RobertaModel(RobertaPreTrainedModel):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_extended_attention_mask,
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -725,6 +816,7 @@ class RobertaModel(RobertaPreTrainedModel):
|
|||||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||||
last_hidden_state=sequence_output,
|
last_hidden_state=sequence_output,
|
||||||
pooler_output=pooled_output,
|
pooler_output=pooled_output,
|
||||||
|
past_key_values=encoder_outputs.past_key_values,
|
||||||
hidden_states=encoder_outputs.hidden_states,
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
attentions=encoder_outputs.attentions,
|
attentions=encoder_outputs.attentions,
|
||||||
cross_attentions=encoder_outputs.cross_attentions,
|
cross_attentions=encoder_outputs.cross_attentions,
|
||||||
@@ -768,6 +860,8 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
|
|||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
@@ -787,6 +881,15 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
|
|||||||
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
||||||
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
||||||
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
@@ -806,6 +909,8 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
|
|||||||
>>> prediction_logits = outputs.logits
|
>>> prediction_logits = outputs.logits
|
||||||
"""
|
"""
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
if labels is not None:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
outputs = self.roberta(
|
outputs = self.roberta(
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -816,6 +921,8 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
|
|||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -839,20 +946,30 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
|
|||||||
return CausalLMOutputWithCrossAttentions(
|
return CausalLMOutputWithCrossAttentions(
|
||||||
loss=lm_loss,
|
loss=lm_loss,
|
||||||
logits=prediction_scores,
|
logits=prediction_scores,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
cross_attentions=outputs.cross_attentions,
|
cross_attentions=outputs.cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
||||||
input_shape = input_ids.shape
|
input_shape = input_ids.shape
|
||||||
|
|
||||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = input_ids.new_ones(input_shape)
|
attention_mask = input_ids.new_ones(input_shape)
|
||||||
|
|
||||||
|
# cut decoder_input_ids if past is used
|
||||||
|
if past is not None:
|
||||||
|
input_ids = input_ids[:, -1:]
|
||||||
|
|
||||||
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||||
|
|
||||||
|
def _reorder_cache(self, past, beam_idx):
|
||||||
|
reordered_past = ()
|
||||||
|
for layer_past in past:
|
||||||
|
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||||
|
return reordered_past
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
|
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
|
||||||
class RobertaForMaskedLM(RobertaPreTrainedModel):
|
class RobertaForMaskedLM(RobertaPreTrainedModel):
|
||||||
@@ -1357,7 +1474,7 @@ class RobertaForQuestionAnswering(RobertaPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_position_ids_from_input_ids(input_ids, padding_idx):
|
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
|
||||||
"""
|
"""
|
||||||
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
||||||
are ignored. This is modified from fairseq's `utils.make_positions`.
|
are ignored. This is modified from fairseq's `utils.make_positions`.
|
||||||
@@ -1369,5 +1486,5 @@ def create_position_ids_from_input_ids(input_ids, padding_idx):
|
|||||||
"""
|
"""
|
||||||
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
||||||
mask = input_ids.ne(padding_idx).int()
|
mask = input_ids.ne(padding_idx).int()
|
||||||
incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
|
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
|
||||||
return incremental_indices.long() + padding_idx
|
return incremental_indices.long() + padding_idx
|
||||||
|
|||||||
@@ -347,6 +347,7 @@ class TapasSelfAttention(nn.Module):
|
|||||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||||
|
|
||||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
@@ -360,6 +361,7 @@ class TapasSelfAttention(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
@@ -367,17 +369,30 @@ class TapasSelfAttention(nn.Module):
|
|||||||
# If this is instantiated as a cross-attention module, the keys
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
# and values come from an encoder; the attention mask needs to be
|
# and values come from an encoder; the attention mask needs to be
|
||||||
# such that the encoder's padding tokens are not attended to.
|
# such that the encoder's padding tokens are not attended to.
|
||||||
if encoder_hidden_states is not None:
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
mixed_key_layer = self.key(encoder_hidden_states)
|
|
||||||
mixed_value_layer = self.value(encoder_hidden_states)
|
if is_cross_attention and past_key_value is not None:
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_layer = past_key_value[0]
|
||||||
|
value_layer = past_key_value[1]
|
||||||
attention_mask = encoder_attention_mask
|
attention_mask = encoder_attention_mask
|
||||||
|
elif is_cross_attention:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif past_key_value is not None:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||||
|
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||||
else:
|
else:
|
||||||
mixed_key_layer = self.key(hidden_states)
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
mixed_value_layer = self.value(hidden_states)
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
|
||||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
|
||||||
value_layer = self.transpose_for_scores(mixed_value_layer)
|
if self.is_decoder:
|
||||||
|
past_key_value = (key_layer, value_layer)
|
||||||
|
|
||||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||||
@@ -404,6 +419,8 @@ class TapasSelfAttention(nn.Module):
|
|||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(*new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (past_key_value,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -455,6 +472,7 @@ class TapasAttention(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
self_outputs = self.self(
|
self_outputs = self.self(
|
||||||
@@ -463,6 +481,7 @@ class TapasAttention(nn.Module):
|
|||||||
head_mask,
|
head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
attention_output = self.output(self_outputs[0], hidden_states)
|
attention_output = self.output(self_outputs[0], hidden_states)
|
||||||
@@ -523,36 +542,60 @@ class TapasLayer(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
self_attention_outputs = self.attention(
|
self_attention_outputs = self.attention(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask,
|
head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
past_key_value=self_attn_past_key_value,
|
||||||
)
|
)
|
||||||
attention_output = self_attention_outputs[0]
|
attention_output = self_attention_outputs[0]
|
||||||
|
|
||||||
|
# if decoder, the last output is tuple of self-attn cache
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = self_attention_outputs[1:-1]
|
||||||
|
present_key_value = self_attention_outputs[-1]
|
||||||
|
else:
|
||||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
|
|
||||||
|
cross_attn_present_key_value = None
|
||||||
if self.is_decoder and encoder_hidden_states is not None:
|
if self.is_decoder and encoder_hidden_states is not None:
|
||||||
assert hasattr(
|
assert hasattr(
|
||||||
self, "crossattention"
|
self, "crossattention"
|
||||||
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
||||||
|
|
||||||
|
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||||
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
cross_attention_outputs = self.crossattention(
|
cross_attention_outputs = self.crossattention(
|
||||||
attention_output,
|
attention_output,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask,
|
head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
cross_attn_past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
attention_output = cross_attention_outputs[0]
|
attention_output = cross_attention_outputs[0]
|
||||||
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
|
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||||
|
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||||
|
present_key_value = present_key_value + cross_attn_present_key_value
|
||||||
|
|
||||||
layer_output = apply_chunking_to_forward(
|
layer_output = apply_chunking_to_forward(
|
||||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||||
)
|
)
|
||||||
outputs = (layer_output,) + outputs
|
outputs = (layer_output,) + outputs
|
||||||
|
|
||||||
|
# if decoder, return the attn key/values as the last output
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (present_key_value,)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def feed_forward_chunk(self, attention_output):
|
def feed_forward_chunk(self, attention_output):
|
||||||
@@ -574,6 +617,8 @@ class TapasEncoder(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
@@ -590,7 +635,7 @@ class TapasEncoder(nn.Module):
|
|||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
return module(*inputs, output_attentions)
|
return module(*inputs, past_key_values, output_attentions)
|
||||||
|
|
||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
@@ -609,6 +654,7 @@ class TapasEncoder(nn.Module):
|
|||||||
layer_head_mask,
|
layer_head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
past_key_values,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ class BartModelTester:
|
|||||||
input_ids = inputs_dict["input_ids"]
|
input_ids = inputs_dict["input_ids"]
|
||||||
|
|
||||||
# first forward pass
|
# first forward pass
|
||||||
outputs = model(input_ids, use_cache=True)
|
outputs = model(input_ids, attention_mask=inputs_dict["attention_mask"], use_cache=True)
|
||||||
|
|
||||||
output, past_key_values = outputs.to_tuple()
|
output, past_key_values = outputs.to_tuple()
|
||||||
|
|
||||||
|
|||||||
@@ -260,6 +260,66 @@ class BertModelTester:
|
|||||||
)
|
)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_past_large_inputs(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
):
|
||||||
|
config.is_decoder = True
|
||||||
|
config.add_cross_attention = True
|
||||||
|
model = BertLMHeadModel(config=config).to(torch_device).eval()
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
outputs = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
use_cache=True,
|
||||||
|
)
|
||||||
|
past_key_values = outputs.past_key_values
|
||||||
|
|
||||||
|
# create hypothetical multiple next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||||
|
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
|
||||||
|
|
||||||
|
# append to next input_ids and
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
|
||||||
|
|
||||||
|
output_from_no_past = model(
|
||||||
|
next_input_ids,
|
||||||
|
attention_mask=next_attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)["hidden_states"][0]
|
||||||
|
output_from_past = model(
|
||||||
|
next_tokens,
|
||||||
|
attention_mask=next_attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)["hidden_states"][0]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
|
||||||
|
|
||||||
|
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
def create_and_check_for_next_sequence_prediction(
|
def create_and_check_for_next_sequence_prediction(
|
||||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
):
|
):
|
||||||
@@ -454,6 +514,10 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
self.model_tester.create_and_check_model_for_causal_lm_as_decoder(*config_and_inputs)
|
self.model_tester.create_and_check_model_for_causal_lm_as_decoder(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_decoder_model_past_with_large_inputs(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
|
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||||
|
|
||||||
def test_for_multiple_choice(self):
|
def test_for_multiple_choice(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
|
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, r
|
|||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
from transformers import BertGenerationConfig, BertGenerationDecoder, BertGenerationEncoder
|
from transformers import BertGenerationConfig, BertGenerationDecoder, BertGenerationEncoder
|
||||||
|
|
||||||
|
|
||||||
@@ -156,6 +158,64 @@ class BertGenerationEncoderTester:
|
|||||||
)
|
)
|
||||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_past_large_inputs(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
input_mask,
|
||||||
|
token_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
config.is_decoder = True
|
||||||
|
config.add_cross_attention = True
|
||||||
|
model = BertGenerationDecoder(config=config).to(torch_device).eval()
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
outputs = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
use_cache=True,
|
||||||
|
)
|
||||||
|
past_key_values = outputs.past_key_values
|
||||||
|
|
||||||
|
# create hypothetical multiple next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||||
|
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
|
||||||
|
|
||||||
|
# append to next input_ids and
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
|
||||||
|
|
||||||
|
output_from_no_past = model(
|
||||||
|
next_input_ids,
|
||||||
|
attention_mask=next_attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)["hidden_states"][0]
|
||||||
|
output_from_past = model(
|
||||||
|
next_tokens,
|
||||||
|
attention_mask=next_attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)["hidden_states"][0]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
|
||||||
|
|
||||||
|
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
def create_and_check_for_causal_lm(
|
def create_and_check_for_causal_lm(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
@@ -203,6 +263,10 @@ class BertGenerationEncoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
|
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_decoder_model_past_with_large_inputs(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
|
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||||
|
|
||||||
def test_model_as_decoder_with_default_input_mask(self):
|
def test_model_as_decoder_with_default_input_mask(self):
|
||||||
# This regression test was failing with PyTorch < 1.3
|
# This regression test was failing with PyTorch < 1.3
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -198,6 +198,74 @@ class RobertaModelTester:
|
|||||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_past_large_inputs(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
):
|
||||||
|
config.is_decoder = True
|
||||||
|
config.add_cross_attention = True
|
||||||
|
model = RobertaForCausalLM(config=config).to(torch_device).eval()
|
||||||
|
|
||||||
|
# make sure that ids don't start with pad token
|
||||||
|
mask = input_ids.ne(config.pad_token_id).long()
|
||||||
|
input_ids = input_ids * mask
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
outputs = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
use_cache=True,
|
||||||
|
)
|
||||||
|
past_key_values = outputs.past_key_values
|
||||||
|
|
||||||
|
# create hypothetical multiple next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||||
|
|
||||||
|
# make sure that ids don't start with pad token
|
||||||
|
mask = next_tokens.ne(config.pad_token_id).long()
|
||||||
|
next_tokens = next_tokens * mask
|
||||||
|
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
|
||||||
|
|
||||||
|
# append to next input_ids and
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
|
||||||
|
|
||||||
|
output_from_no_past = model(
|
||||||
|
next_input_ids,
|
||||||
|
attention_mask=next_attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)["hidden_states"][0]
|
||||||
|
output_from_past = model(
|
||||||
|
next_tokens,
|
||||||
|
attention_mask=next_attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)["hidden_states"][0]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
|
||||||
|
|
||||||
|
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
def create_and_check_for_masked_lm(
|
def create_and_check_for_masked_lm(
|
||||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
):
|
):
|
||||||
@@ -337,6 +405,10 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
|
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_decoder_model_past_with_large_inputs(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
|
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||||
|
|
||||||
def test_for_masked_lm(self):
|
def test_for_masked_lm(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user