TF BART models - Add cross_attentions to model output and fix cross-attention head masking (#10699)

* Add cross_attn_head_mask to BART

* Fix cross_attentions in TFBart-like models

* This commit enables returning of `cross_attentions`
for TFBart-like models

* It also fixes attention head masking in the cross-attention module

* Update TF model templates

* Fix missing , in TF model templates

* Fix typo: congig -> config
This commit is contained in:
Daniel Stancl
2021-04-26 14:16:21 +02:00
committed by GitHub
parent 4bd6b54fa4
commit 38a716cd41
15 changed files with 643 additions and 216 deletions

View File

@@ -116,6 +116,82 @@ class TFBaseModelOutputWithPast(ModelOutput):
attentions: Optional[Tuple[tf.Tensor]] = None attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFBaseModelOutputWithCrossAttentions(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states, self-attentions and cross-attentions.

    Args:
        last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
            shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            target_sequence_length, source_sequence_length)`, i.e. decoder positions attending over encoder positions.

            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
    """

    # Every field defaults to None so that outputs which were not requested can simply be left unset.
    last_hidden_state: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
    cross_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding) and
    cross-attentions.

    Args:
        last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.

            If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
            1, hidden_size)` is output.
        past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
            List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size,
            num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
            :obj:`past_key_values` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
            shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            target_sequence_length, source_sequence_length)`, i.e. decoder positions attending over encoder positions.

            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
    """

    # Every field defaults to None so that outputs which were not requested can simply be left unset.
    last_hidden_state: tf.Tensor = None
    past_key_values: Optional[List[tf.Tensor]] = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
    cross_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass @dataclass
class TFSeq2SeqModelOutput(ModelOutput): class TFSeq2SeqModelOutput(ModelOutput):
""" """
@@ -145,6 +221,12 @@ class TFSeq2SeqModelOutput(ModelOutput):
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads. self-attention heads.
cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): encoder_last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model. Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
@@ -164,6 +246,7 @@ class TFSeq2SeqModelOutput(ModelOutput):
past_key_values: Optional[List[tf.Tensor]] = None past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None decoder_attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None encoder_last_hidden_state: Optional[tf.Tensor] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None encoder_attentions: Optional[Tuple[tf.Tensor]] = None
@@ -290,6 +373,12 @@ class TFSeq2SeqLMOutput(ModelOutput):
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads. self-attention heads.
cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): encoder_last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model. Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
@@ -310,6 +399,7 @@ class TFSeq2SeqLMOutput(ModelOutput):
past_key_values: Optional[List[tf.Tensor]] = None past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None decoder_attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None encoder_last_hidden_state: Optional[tf.Tensor] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None encoder_attentions: Optional[Tuple[tf.Tensor]] = None

View File

@@ -30,7 +30,7 @@ from ...file_utils import (
) )
from ...modeling_tf_outputs import ( from ...modeling_tf_outputs import (
TFBaseModelOutput, TFBaseModelOutput,
TFBaseModelOutputWithPast, TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput, TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput, TFSeq2SeqModelOutput,
) )
@@ -365,7 +365,7 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None, cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False, training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -379,8 +379,8 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)` `(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module.
`(encoder_attention_heads,)` `(decoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
""" """
residual = hidden_states residual = hidden_states
@@ -401,16 +401,17 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
# Cross-Attention Block # Cross-Attention Block
cross_attn_present_key_value = None cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None: if encoder_hidden_states is not None:
residual = hidden_states residual = hidden_states
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, _, cross_attn_present_key_value = self.encoder_attn( hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
@@ -432,6 +433,7 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
return ( return (
hidden_states, hidden_states,
self_attn_weights, self_attn_weights,
cross_attn_weights,
present_key_value, present_key_value,
) )
@@ -572,7 +574,7 @@ BART_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
@@ -580,6 +582,12 @@ BART_INPUTS_DOCSTRING = r"""
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`): encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
@@ -677,7 +685,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
@@ -814,7 +822,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
@@ -856,14 +864,13 @@ class TFBartDecoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -894,7 +901,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
@@ -949,16 +956,18 @@ class TFBartDecoder(tf.keras.layers.Layer):
# decoder layers # decoder layers
all_hidden_states = () if inputs["output_hidden_states"] else None all_hidden_states = () if inputs["output_hidden_states"] else None
all_self_attns = () if inputs["output_attentions"] else None all_self_attns = () if inputs["output_attentions"] else None
all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
present_key_values = () if inputs["use_cache"] else None present_key_values = () if inputs["use_cache"] else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they # The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager. # have to be disabled in other modes than eager.
if inputs["head_mask"] is not None and tf.executing_eagerly(): for attn_mask in ["head_mask", "cross_attn_head_mask"]:
if inputs[attn_mask] is not None and tf.executing_eagerly():
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0], shape_list(inputs[attn_mask])[0],
len(self.layers), len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
) )
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
@@ -973,14 +982,14 @@ class TFBartDecoder(tf.keras.layers.Layer):
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
hidden_states, layer_self_attn, present_key_value = decoder_layer( hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states, hidden_states,
attention_mask=combined_attention_mask, attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"], encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx] cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
if inputs["encoder_head_mask"] is not None if inputs["cross_attn_head_mask"] is not None
else None, else None,
past_key_value=past_key_value, past_key_value=past_key_value,
) )
@@ -991,23 +1000,30 @@ class TFBartDecoder(tf.keras.layers.Layer):
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_self_attns += (layer_self_attn,) all_self_attns += (layer_self_attn,)
if inputs["encoder_hidden_states"] is not None:
all_cross_attns += (layer_cross_attn,)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_self_attns = list(all_self_attns) all_self_attns = list(all_self_attns)
if inputs["encoder_hidden_states"] is not None:
all_cross_attns = list(all_cross_attns)
if inputs["use_cache"]: if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values) present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]: if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else: else:
return TFBaseModelOutputWithPast( return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_values, past_key_values=present_key_values,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
cross_attentions=all_cross_attns,
) )
@@ -1057,6 +1073,7 @@ class TFBartMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1077,6 +1094,7 @@ class TFBartMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1131,7 +1149,7 @@ class TFBartMainLayer(tf.keras.layers.Layer):
encoder_hidden_states=inputs["encoder_outputs"][0], encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"], encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"], head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"], cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"], inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"], use_cache=inputs["use_cache"],
@@ -1149,6 +1167,7 @@ class TFBartMainLayer(tf.keras.layers.Layer):
past_key_values=decoder_outputs.past_key_values, past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
@@ -1189,6 +1208,7 @@ class TFBartModel(TFBartPretrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1209,6 +1229,7 @@ class TFBartModel(TFBartPretrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1228,6 +1249,7 @@ class TFBartModel(TFBartPretrainedModel):
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
@@ -1245,6 +1267,7 @@ class TFBartModel(TFBartPretrainedModel):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
@@ -1253,6 +1276,7 @@ class TFBartModel(TFBartPretrainedModel):
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
@@ -1309,6 +1333,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1339,6 +1364,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1372,6 +1398,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
@@ -1394,6 +1421,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
past_key_values=outputs.past_key_values, # index 1 of d outputs past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out
@@ -1403,6 +1431,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
@@ -1411,6 +1440,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,

View File

@@ -32,7 +32,7 @@ from ...file_utils import (
) )
from ...modeling_tf_outputs import ( from ...modeling_tf_outputs import (
TFBaseModelOutput, TFBaseModelOutput,
TFBaseModelOutputWithPast, TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput, TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput, TFSeq2SeqModelOutput,
) )
@@ -370,7 +370,7 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None, cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False, training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -384,8 +384,8 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)` `(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module.
`(encoder_attention_heads,)` `(decoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
""" """
residual = hidden_states residual = hidden_states
@@ -406,17 +406,18 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
# Cross-Attention Block # Cross-Attention Block
cross_attn_present_key_value = None cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None: if encoder_hidden_states is not None:
residual = hidden_states residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states) hidden_states = self.encoder_attn_layer_norm(hidden_states)
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, _, cross_attn_present_key_value = self.encoder_attn( hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
@@ -437,6 +438,7 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
return ( return (
hidden_states, hidden_states,
self_attn_weights, self_attn_weights,
cross_attn_weights,
present_key_value, present_key_value,
) )
@@ -569,7 +571,7 @@ BLENDERBOT_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
@@ -577,6 +579,12 @@ BLENDERBOT_INPUTS_DOCSTRING = r"""
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`): encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
@@ -674,7 +682,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
@@ -818,7 +826,7 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
@@ -860,14 +868,13 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -904,7 +911,7 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
@@ -957,18 +964,20 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
hidden_states = self.dropout(hidden_states, training=inputs["training"]) hidden_states = self.dropout(hidden_states, training=inputs["training"])
# decoder layers # decoder layers
all_hidden_states = () all_hidden_states = () if inputs["output_hidden_states"] else None
all_self_attns = () all_self_attns = () if inputs["output_attentions"] else None
present_key_values = () all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
present_key_values = () if inputs["use_cache"] else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA, so they # The tf.debugging asserts are not compliant with XLA, so they
# have to be disabled in modes other than eager. # have to be disabled in modes other than eager.
if inputs["head_mask"] is not None and tf.executing_eagerly(): for attn_mask in ["head_mask", "cross_attn_head_mask"]:
if inputs[attn_mask] is not None and tf.executing_eagerly():
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0], shape_list(inputs[attn_mask])[0],
len(self.layers), len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
) )
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
@@ -982,14 +991,14 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
hidden_states, layer_self_attn, present_key_value = decoder_layer( hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states, hidden_states,
attention_mask=combined_attention_mask, attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"], encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx] cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
if inputs["encoder_head_mask"] is not None if inputs["cross_attn_head_mask"] is not None
else None, else None,
past_key_value=past_key_value, past_key_value=past_key_value,
) )
@@ -1000,25 +1009,32 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_self_attns += (layer_self_attn,) all_self_attns += (layer_self_attn,)
if inputs["encoder_hidden_states"] is not None:
all_cross_attns += (layer_cross_attn,)
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
else:
all_hidden_states = None
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None if inputs["output_attentions"]:
all_self_attns = list(all_self_attns)
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None if inputs["encoder_hidden_states"] is not None:
all_cross_attns = list(all_cross_attns)
if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]: if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else: else:
return TFBaseModelOutputWithPast( return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_values, past_key_values=present_key_values,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
cross_attentions=all_cross_attns,
) )
@@ -1065,6 +1081,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1085,6 +1102,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1131,7 +1149,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
encoder_hidden_states=inputs["encoder_outputs"][0], encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"], encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"], head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"], cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"], inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"], use_cache=inputs["use_cache"],
@@ -1149,6 +1167,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
past_key_values=decoder_outputs.past_key_values, past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
@@ -1199,6 +1218,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1219,6 +1239,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1238,6 +1259,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
@@ -1256,6 +1278,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
@@ -1264,6 +1287,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
@@ -1331,6 +1355,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1361,6 +1386,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1394,6 +1420,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
@@ -1416,6 +1443,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
past_key_values=outputs.past_key_values, # index 1 of d outputs past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out
@@ -1426,6 +1454,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
@@ -1434,6 +1463,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,

View File

@@ -30,7 +30,7 @@ from ...file_utils import (
) )
from ...modeling_tf_outputs import ( from ...modeling_tf_outputs import (
TFBaseModelOutput, TFBaseModelOutput,
TFBaseModelOutputWithPast, TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput, TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput, TFSeq2SeqModelOutput,
) )
@@ -369,7 +369,7 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None, cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False, training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -383,8 +383,8 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)` `(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module.
`(encoder_attention_heads,)` `(decoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
""" """
residual = hidden_states residual = hidden_states
@@ -405,16 +405,17 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
# Cross-Attention Block # Cross-Attention Block
cross_attn_present_key_value = None cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None: if encoder_hidden_states is not None:
residual = hidden_states residual = hidden_states
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, _, cross_attn_present_key_value = self.encoder_attn( hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
@@ -436,6 +437,7 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
return ( return (
hidden_states, hidden_states,
self_attn_weights, self_attn_weights,
cross_attn_weights,
present_key_value, present_key_value,
) )
@@ -574,7 +576,7 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
@@ -582,6 +584,12 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r"""
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`): encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
@@ -679,7 +687,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
@@ -823,7 +831,7 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
@@ -865,14 +873,13 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -909,7 +916,7 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
@@ -960,18 +967,20 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
hidden_states = self.dropout(hidden_states, training=inputs["training"]) hidden_states = self.dropout(hidden_states, training=inputs["training"])
# decoder layers # decoder layers
all_hidden_states = () all_hidden_states = () if inputs["output_hidden_states"] else None
all_self_attns = () all_self_attns = () if inputs["output_attentions"] else None
present_key_values = () all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
present_key_values = () if inputs["use_cache"] else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA, so they # The tf.debugging asserts are not compliant with XLA, so they
# have to be disabled in other modes than eager. # have to be disabled in other modes than eager.
if inputs["head_mask"] is not None and tf.executing_eagerly(): for attn_mask in ["head_mask", "cross_attn_head_mask"]:
if inputs[attn_mask] is not None and tf.executing_eagerly():
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0], shape_list(inputs[attn_mask])[0],
len(self.layers), len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
) )
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
@@ -985,14 +994,14 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
hidden_states, layer_self_attn, present_key_value = decoder_layer( hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states, hidden_states,
attention_mask=combined_attention_mask, attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"], encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx] cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
if inputs["encoder_head_mask"] is not None if inputs["cross_attn_head_mask"] is not None
else None, else None,
past_key_value=past_key_value, past_key_value=past_key_value,
) )
@@ -1003,23 +1012,30 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_self_attns += (layer_self_attn,) all_self_attns += (layer_self_attn,)
if inputs["encoder_hidden_states"] is not None:
all_cross_attns += (layer_cross_attn,)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
else:
all_hidden_states = None
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None if inputs["output_attentions"]:
all_self_attns = list(all_self_attns)
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None if inputs["encoder_hidden_states"] is not None:
all_cross_attns = list(all_cross_attns)
if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]: if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else: else:
return TFBaseModelOutputWithPast( return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_values, past_key_values=present_key_values,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
cross_attentions=all_cross_attns,
) )
@@ -1066,6 +1082,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1086,6 +1103,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1132,7 +1150,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
encoder_hidden_states=inputs["encoder_outputs"][0], encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"], encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"], head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"], cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"], inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"], use_cache=inputs["use_cache"],
@@ -1150,6 +1168,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
past_key_values=decoder_outputs.past_key_values, past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
@@ -1187,6 +1206,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1207,6 +1227,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1226,6 +1247,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
@@ -1244,6 +1266,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
@@ -1252,6 +1275,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
@@ -1306,6 +1330,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1336,6 +1361,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1369,6 +1395,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
@@ -1391,6 +1418,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
past_key_values=outputs.past_key_values, # index 1 of d outputs past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out
@@ -1401,6 +1429,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
@@ -1409,6 +1438,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,

View File

@@ -31,7 +31,7 @@ from ...file_utils import (
) )
from ...modeling_tf_outputs import ( from ...modeling_tf_outputs import (
TFBaseModelOutput, TFBaseModelOutput,
TFBaseModelOutputWithPast, TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput, TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput, TFSeq2SeqModelOutput,
) )
@@ -408,7 +408,7 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None, cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False, training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -422,8 +422,8 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)` `(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for cross-attention heads in a given layer of size
`(encoder_attention_heads,)` `(decoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
""" """
residual = hidden_states residual = hidden_states
@@ -444,16 +444,17 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
# Cross-Attention Block # Cross-Attention Block
cross_attn_present_key_value = None cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None: if encoder_hidden_states is not None:
residual = hidden_states residual = hidden_states
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, _, cross_attn_present_key_value = self.encoder_attn( hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
@@ -475,6 +476,7 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
return ( return (
hidden_states, hidden_states,
self_attn_weights, self_attn_weights,
cross_attn_weights,
present_key_value, present_key_value,
) )
@@ -603,7 +605,7 @@ MARIAN_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
@@ -611,6 +613,12 @@ MARIAN_INPUTS_DOCSTRING = r"""
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`): encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
@@ -707,7 +715,7 @@ class TFMarianEncoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
@@ -848,7 +856,7 @@ class TFMarianDecoder(tf.keras.layers.Layer):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
@@ -890,14 +898,13 @@ class TFMarianDecoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -934,7 +941,7 @@ class TFMarianDecoder(tf.keras.layers.Layer):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
@@ -986,18 +993,20 @@ class TFMarianDecoder(tf.keras.layers.Layer):
hidden_states = self.dropout(hidden_states + positions, training=inputs["training"]) hidden_states = self.dropout(hidden_states + positions, training=inputs["training"])
# decoder layers # decoder layers
all_hidden_states = () all_hidden_states = () if inputs["output_hidden_states"] else None
all_self_attns = () all_self_attns = () if inputs["output_attentions"] else None
present_key_values = () all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
present_key_values = () if inputs["use_cache"] else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA, so they # The tf.debugging asserts are not compliant with XLA, so they
# have to be disabled in other modes than eager. # have to be disabled in other modes than eager.
if inputs["head_mask"] is not None and tf.executing_eagerly(): for attn_mask in ["head_mask", "cross_attn_head_mask"]:
if inputs[attn_mask] is not None and tf.executing_eagerly():
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0], shape_list(inputs[attn_mask])[0],
len(self.layers), len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
) )
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
@@ -1011,14 +1020,14 @@ class TFMarianDecoder(tf.keras.layers.Layer):
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
hidden_states, layer_self_attn, present_key_value = decoder_layer( hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states, hidden_states,
attention_mask=combined_attention_mask, attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"], encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx] cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
if inputs["encoder_head_mask"] is not None if inputs["cross_attn_head_mask"] is not None
else None, else None,
past_key_value=past_key_value, past_key_value=past_key_value,
) )
@@ -1029,23 +1038,30 @@ class TFMarianDecoder(tf.keras.layers.Layer):
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_self_attns += (layer_self_attn,) all_self_attns += (layer_self_attn,)
if inputs["encoder_hidden_states"] is not None:
all_cross_attns += (layer_cross_attn,)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
else:
all_hidden_states = None
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None if inputs["output_attentions"]:
all_self_attns = list(all_self_attns)
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None if inputs["encoder_hidden_states"] is not None:
all_cross_attns = list(all_cross_attns)
if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]: if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else: else:
return TFBaseModelOutputWithPast( return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_values, past_key_values=present_key_values,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
cross_attentions=all_cross_attns,
) )
@@ -1092,6 +1108,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1112,6 +1129,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1161,7 +1179,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
encoder_hidden_states=inputs["encoder_outputs"][0], encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"], encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"], head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"], cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"], inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"], use_cache=inputs["use_cache"],
@@ -1179,6 +1197,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
past_key_values=decoder_outputs.past_key_values, past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
@@ -1216,6 +1235,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1235,6 +1255,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
@@ -1255,6 +1276,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
@@ -1273,6 +1295,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
@@ -1281,6 +1304,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
@@ -1335,6 +1359,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1365,6 +1390,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1398,6 +1424,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
@@ -1420,6 +1447,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
past_key_values=outputs.past_key_values, # index 1 of d outputs past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
encoder_last_hidden_state=outputs.last_hidden_state, # index 0 of encoder outputs encoder_last_hidden_state=outputs.last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out
@@ -1430,6 +1458,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
@@ -1438,6 +1467,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,

View File

@@ -30,7 +30,7 @@ from ...file_utils import (
) )
from ...modeling_tf_outputs import ( from ...modeling_tf_outputs import (
TFBaseModelOutput, TFBaseModelOutput,
TFBaseModelOutputWithPast, TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput, TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput, TFSeq2SeqModelOutput,
) )
@@ -368,7 +368,7 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None, cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False, training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -382,8 +382,8 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)` `(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module.
`(encoder_attention_heads,)` `(decoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
""" """
residual = hidden_states residual = hidden_states
@@ -404,17 +404,18 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
# Cross-Attention Block # Cross-Attention Block
cross_attn_present_key_value = None cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None: if encoder_hidden_states is not None:
residual = hidden_states residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states) hidden_states = self.encoder_attn_layer_norm(hidden_states)
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, _, cross_attn_present_key_value = self.encoder_attn( hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
@@ -435,6 +436,7 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
return ( return (
hidden_states, hidden_states,
self_attn_weights, self_attn_weights,
cross_attn_weights,
present_key_value, present_key_value,
) )
@@ -547,7 +549,7 @@ MBART_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
@@ -555,6 +557,12 @@ MBART_INPUTS_DOCSTRING = r"""
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`): encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
@@ -828,7 +836,7 @@ class TFMBartDecoder(tf.keras.layers.Layer):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
@@ -870,14 +878,13 @@ class TFMBartDecoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -914,7 +921,7 @@ class TFMBartDecoder(tf.keras.layers.Layer):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
@@ -967,18 +974,20 @@ class TFMBartDecoder(tf.keras.layers.Layer):
hidden_states = self.dropout(hidden_states, training=inputs["training"]) hidden_states = self.dropout(hidden_states, training=inputs["training"])
# decoder layers # decoder layers
all_hidden_states = () all_hidden_states = () if inputs["output_hidden_states"] else None
all_self_attns = () all_self_attns = () if inputs["output_attentions"] else None
present_key_values = () all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
present_key_values = () if inputs["use_cache"] else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they # The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager. # have to be disabled in other modes than eager.
if inputs["head_mask"] is not None and tf.executing_eagerly(): for attn_mask in ["head_mask", "cross_attn_head_mask"]:
if inputs[attn_mask] is not None and tf.executing_eagerly():
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0], shape_list(inputs[attn_mask])[0],
len(self.layers), len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
) )
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
@@ -992,14 +1001,14 @@ class TFMBartDecoder(tf.keras.layers.Layer):
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
hidden_states, layer_self_attn, present_key_value = decoder_layer( hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states, hidden_states,
attention_mask=combined_attention_mask, attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"], encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx] cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
if inputs["encoder_head_mask"] is not None if inputs["cross_attn_head_mask"] is not None
else None, else None,
past_key_value=past_key_value, past_key_value=past_key_value,
) )
@@ -1010,25 +1019,32 @@ class TFMBartDecoder(tf.keras.layers.Layer):
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_self_attns += (layer_self_attn,) all_self_attns += (layer_self_attn,)
if inputs["encoder_hidden_states"] is not None:
all_cross_attns += (layer_cross_attn,)
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
else:
all_hidden_states = None
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None if inputs["output_attentions"]:
all_self_attns = list(all_self_attns)
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None if inputs["encoder_hidden_states"] is not None:
all_cross_attns = list(all_cross_attns)
if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]: if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else: else:
return TFBaseModelOutputWithPast( return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_values, past_key_values=present_key_values,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
cross_attentions=all_cross_attns,
) )
@@ -1075,6 +1091,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1095,6 +1112,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1147,7 +1165,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
encoder_hidden_states=inputs["encoder_outputs"][0], encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"], encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"], head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"], cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"], inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"], use_cache=inputs["use_cache"],
@@ -1165,6 +1183,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
past_key_values=decoder_outputs.past_key_values, past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
@@ -1202,6 +1221,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1222,6 +1242,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1241,6 +1262,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
@@ -1259,6 +1281,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
@@ -1267,6 +1290,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
@@ -1321,6 +1345,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1351,6 +1376,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1382,6 +1408,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
@@ -1404,6 +1431,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
past_key_values=outputs.past_key_values, # index 1 of d outputs past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out
@@ -1414,6 +1442,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
@@ -1422,6 +1451,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,

View File

@@ -31,7 +31,7 @@ from ...file_utils import (
) )
from ...modeling_tf_outputs import ( from ...modeling_tf_outputs import (
TFBaseModelOutput, TFBaseModelOutput,
TFBaseModelOutputWithPast, TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput, TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput, TFSeq2SeqModelOutput,
) )
@@ -409,7 +409,7 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None, cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False, training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -423,8 +423,8 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)` `(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module.
`(encoder_attention_heads,)` `(decoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
""" """
residual = hidden_states residual = hidden_states
@@ -445,17 +445,18 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
# Cross-Attention Block # Cross-Attention Block
cross_attn_present_key_value = None cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None: if encoder_hidden_states is not None:
residual = hidden_states residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states) hidden_states = self.encoder_attn_layer_norm(hidden_states)
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, _, cross_attn_present_key_value = self.encoder_attn( hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
@@ -476,6 +477,7 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
return ( return (
hidden_states, hidden_states,
self_attn_weights, self_attn_weights,
cross_attn_weights,
present_key_value, present_key_value,
) )
@@ -603,7 +605,7 @@ PEGASUS_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
@@ -611,6 +613,12 @@ PEGASUS_INPUTS_DOCSTRING = r"""
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`): encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
@@ -855,7 +863,7 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
@@ -897,14 +905,13 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -941,7 +948,7 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
@@ -993,18 +1000,20 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
hidden_states = self.dropout(hidden_states + positions, training=inputs["training"]) hidden_states = self.dropout(hidden_states + positions, training=inputs["training"])
# decoder layers # decoder layers
all_hidden_states = () all_hidden_states = () if inputs["output_hidden_states"] else None
all_self_attns = () all_self_attns = () if inputs["output_attentions"] else None
present_key_values = () all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
present_key_values = () if inputs["use_cache"] else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they # The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager. # have to be disabled in other modes than eager.
if inputs["head_mask"] is not None and tf.executing_eagerly(): for attn_mask in ["head_mask", "cross_attn_head_mask"]:
if inputs[attn_mask] is not None and tf.executing_eagerly():
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0], shape_list(inputs[attn_mask])[0],
len(self.layers), len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
) )
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
@@ -1018,14 +1027,14 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
hidden_states, layer_self_attn, present_key_value = decoder_layer( hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states, hidden_states,
attention_mask=combined_attention_mask, attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"], encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx] cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
if inputs["encoder_head_mask"] is not None if inputs["cross_attn_head_mask"] is not None
else None, else None,
past_key_value=past_key_value, past_key_value=past_key_value,
) )
@@ -1036,25 +1045,32 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_self_attns += (layer_self_attn,) all_self_attns += (layer_self_attn,)
if inputs["encoder_hidden_states"] is not None:
all_cross_attns += (layer_cross_attn,)
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
else:
all_hidden_states = None
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None if inputs["output_attentions"]:
all_self_attns = list(all_self_attns)
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None if inputs["encoder_hidden_states"] is not None:
all_cross_attns = list(all_cross_attns)
if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]: if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else: else:
return TFBaseModelOutputWithPast( return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_values, past_key_values=present_key_values,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
cross_attentions=all_cross_attns,
) )
@@ -1101,6 +1117,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1121,6 +1138,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1170,7 +1188,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
encoder_hidden_states=inputs["encoder_outputs"][0], encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"], encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"], head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"], cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"], inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"], use_cache=inputs["use_cache"],
@@ -1188,6 +1206,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
past_key_values=decoder_outputs.past_key_values, past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
@@ -1225,6 +1244,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1245,6 +1265,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1264,6 +1285,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
@@ -1282,6 +1304,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
@@ -1290,6 +1313,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
@@ -1344,6 +1368,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1374,6 +1399,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1407,6 +1433,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
@@ -1429,6 +1456,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
past_key_values=outputs.past_key_values, # index 1 of d outputs past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out
@@ -1439,6 +1467,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
@@ -1447,6 +1476,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,

View File

@@ -147,7 +147,6 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
return final_embeddings return final_embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}} # Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}}
class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer): class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs): def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
@@ -352,6 +351,7 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer):
return outputs return outputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->{{cookiecutter.camelcase_modelname}} # Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->{{cookiecutter.camelcase_modelname}}
class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer): class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs): def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
@@ -625,7 +625,6 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
base_model_prefix = "{{cookiecutter.lowercase_modelname}}" base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
@@ -885,6 +884,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns) return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
@add_start_docstrings( @add_start_docstrings(
"""{{cookiecutter.modelname}} Model with a `language modeling` head on top for CLM fine-tuning. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING """{{cookiecutter.modelname}} Model with a `language modeling` head on top for CLM fine-tuning. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING
) )
@@ -1728,16 +1728,18 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer):
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2") self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False): def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
""" """
Args: Args:
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`tf.Tensor`): attention mask of size attention_mask (:obj:`tf.Tensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`
""" """
residual = hidden_states residual = hidden_states
hidden_states, self_attn_weights, _ = self.self_attn( hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
) )
# The tf.debugging asserts are not compliant with XLA then they # The tf.debugging asserts are not compliant with XLA then they
@@ -1798,6 +1800,8 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False, training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -1809,6 +1813,10 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)`
cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for cross-attention heads in a given layer of size
`(decoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
""" """
residual = hidden_states residual = hidden_states
@@ -1821,6 +1829,7 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states, hidden_states=hidden_states,
past_key_value=self_attn_past_key_value, past_key_value=self_attn_past_key_value,
attention_mask=attention_mask, attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
@@ -1828,15 +1837,17 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
# Cross-Attention Block # Cross-Attention Block
cross_attn_present_key_value = None cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None: if encoder_hidden_states is not None:
residual = hidden_states residual = hidden_states
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, _, cross_attn_present_key_value = self.encoder_attn( hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
@@ -1858,6 +1869,7 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
return ( return (
hidden_states, hidden_states,
self_attn_weights, self_attn_weights,
cross_attn_layer_head_mask,
present_key_value, present_key_value,
) )
@@ -1965,6 +1977,24 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
the right for denoising pre-training following the paper. the right for denoising pre-training following the paper.
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`): encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
@@ -2013,7 +2043,6 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
self.max_source_positions = config.max_position_embeddings self.max_source_positions = config.max_position_embeddings
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
self.embed_positions = TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding( self.embed_positions = TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(
config.max_position_embeddings, config.max_position_embeddings,
@@ -2034,6 +2063,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
input_ids=None, input_ids=None,
inputs_embeds=None, inputs_embeds=None,
attention_mask=None, attention_mask=None,
head_mask=None,
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
@@ -2058,6 +2088,12 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
@@ -2082,6 +2118,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
config=self.config, config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
@@ -2115,6 +2152,16 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
encoder_states = () if inputs["output_hidden_states"] else None encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compatible with XLA, so they
# have to be disabled in modes other than eager.
if inputs["head_mask"] is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers # encoder layers
for encoder_layer in self.layers: for encoder_layer in self.layers:
@@ -2125,7 +2172,11 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
continue continue
hidden_states, attn = encoder_layer(hidden_states, inputs["attention_mask"]) hidden_states, attn = encoder_layer(
hidden_states,
inputs["attention_mask"],
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
)
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_attentions += (attn,) all_attentions += (attn,)
@@ -2181,6 +2232,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
attention_mask=None, attention_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
@@ -2218,6 +2271,18 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding. decoding.
@@ -2252,6 +2317,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
attention_mask=attention_mask, attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
@@ -2297,8 +2364,20 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
# decoder layers # decoder layers
all_hidden_states = () if inputs["output_hidden_states"] else None all_hidden_states = () if inputs["output_hidden_states"] else None
all_self_attns = () if inputs["output_attentions"] else None all_self_attns = () if inputs["output_attentions"] else None
all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
present_key_values = () if inputs["use_cache"] else None present_key_values = () if inputs["use_cache"] else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compatible with XLA, so they
# have to be disabled in modes other than eager.
for attn_mask in ["head_mask", "cross_attn_head_mask"]:
if inputs[attn_mask] is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs[attn_mask])[0],
len(self.layers),
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
@@ -2311,11 +2390,15 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
hidden_states, layer_self_attn, present_key_value = decoder_layer( hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states, hidden_states,
attention_mask=combined_attention_mask, attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"], encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
if inputs["cross_attn_head_mask"] is not None
else None,
past_key_value=past_key_value, past_key_value=past_key_value,
) )
@@ -2325,23 +2408,30 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_self_attns += (layer_self_attn,) all_self_attns += (layer_self_attn,)
if inputs["encoder_hidden_states"] is not None:
all_cross_attns += (layer_cross_attn,)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_self_attns = list(all_self_attns) all_self_attns = list(all_self_attns)
if inputs["encoder_hidden_states"] is not None:
all_cross_attns = list(all_cross_attns)
if inputs["use_cache"]: if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values) present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]: if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else: else:
return TFBaseModelOutputWithPast( return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_values, past_key_values=present_key_values,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
cross_attentions=all_cross_attns,
) )
@tf.function @tf.function
@@ -2413,6 +2503,9 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -2431,6 +2524,9 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -2450,6 +2546,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
inputs["encoder_outputs"] = self.encoder( inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
@@ -2472,6 +2569,8 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
attention_mask=inputs["decoder_attention_mask"], attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=inputs["encoder_outputs"][0], encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"], encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"], inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"], use_cache=inputs["use_cache"],
@@ -2489,6 +2588,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
past_key_values=decoder_outputs.past_key_values, past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
@@ -2524,6 +2624,9 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -2542,6 +2645,9 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -2559,6 +2665,9 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"], decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
@@ -2577,6 +2686,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
@@ -2585,6 +2695,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
@@ -2637,6 +2748,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -2672,6 +2786,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -2698,6 +2815,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
decoder_input_ids=inputs["decoder_input_ids"], decoder_input_ids=inputs["decoder_input_ids"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
@@ -2720,6 +2840,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
past_key_values=outputs.past_key_values, # index 1 of d outputs past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out
@@ -2730,6 +2851,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
@@ -2738,6 +2860,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,

View File

@@ -147,6 +147,7 @@ def prepare_bart_inputs_dict(
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -162,13 +163,16 @@ def prepare_bart_inputs_dict(
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads)) head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
if cross_attn_head_mask is None:
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask, "decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }

View File

@@ -146,6 +146,7 @@ def prepare_blenderbot_inputs_dict(
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -161,6 +162,8 @@ def prepare_blenderbot_inputs_dict(
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads)) head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
if cross_attn_head_mask is None:
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
@@ -168,6 +171,7 @@ def prepare_blenderbot_inputs_dict(
"decoder_attention_mask": decoder_attention_mask, "decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }

View File

@@ -146,6 +146,7 @@ def prepare_blenderbot_small_inputs_dict(
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -161,6 +162,8 @@ def prepare_blenderbot_small_inputs_dict(
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads)) head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
if cross_attn_head_mask is None:
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
@@ -168,6 +171,7 @@ def prepare_blenderbot_small_inputs_dict(
"decoder_attention_mask": decoder_attention_mask, "decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }

View File

@@ -190,8 +190,12 @@ class TFModelTesterMixin:
"decoder_attention_mask", "decoder_attention_mask",
] ]
expected_arg_names.extend( expected_arg_names.extend(
["head_mask", "decoder_head_mask", "encoder_outputs"] ["head_mask", "decoder_head_mask"] if "head_mask" and "decoder_head_mask" in arg_names else []
if "head_mask" and "decoder_head_mask" in arg_names )
# Necessary to handle BART with newly added cross_attn_head_mask
expected_arg_names.extend(
["cross_attn_head_mask", "encoder_outputs"]
if "cross_attn_head_mask" in arg_names
else ["encoder_outputs"] else ["encoder_outputs"]
) )
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
@@ -512,6 +516,8 @@ class TFModelTesterMixin:
del inputs_dict["head_mask"] del inputs_dict["head_mask"]
if "decoder_head_mask" in inputs_dict: if "decoder_head_mask" in inputs_dict:
del inputs_dict["decoder_head_mask"] del inputs_dict["decoder_head_mask"]
if "cross_attn_head_mask" in inputs_dict:
del inputs_dict["cross_attn_head_mask"]
tf_main_layer_classes = set( tf_main_layer_classes = set(
module_member module_member
for model_class in self.all_model_classes for model_class in self.all_model_classes
@@ -639,7 +645,7 @@ class TFModelTesterMixin:
def check_decoder_attentions_output(outputs): def check_decoder_attentions_output(outputs):
out_len = len(outputs) out_len = len(outputs)
self.assertEqual(out_len % 2, 0) self.assertEqual(min(out_len % 2, out_len % 5), 0) # differentiation due to newly added cross_attentions
decoder_attentions = outputs.decoder_attentions decoder_attentions = outputs.decoder_attentions
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual( self.assertListEqual(
@@ -733,6 +739,8 @@ class TFModelTesterMixin:
arg_names = [*signature.parameters.keys()] arg_names = [*signature.parameters.keys()]
if "decoder_head_mask" in arg_names: # necessary differentiation because of T5 model if "decoder_head_mask" in arg_names: # necessary differentiation because of T5 model
inputs["decoder_head_mask"] = head_mask inputs["decoder_head_mask"] = head_mask
if "cross_attn_head_mask" in arg_names:
inputs["cross_attn_head_mask"] = head_mask
outputs = model(**inputs, return_dict=True) outputs = model(**inputs, return_dict=True)
@@ -757,6 +765,8 @@ class TFModelTesterMixin:
if model.config.is_encoder_decoder: if model.config.is_encoder_decoder:
check_attentions_validity(outputs.encoder_attentions) check_attentions_validity(outputs.encoder_attentions)
check_attentions_validity(outputs.decoder_attentions) check_attentions_validity(outputs.decoder_attentions)
if "cross_attn_head_mask" in arg_names:
check_attentions_validity(outputs.cross_attentions)
else: else:
check_attentions_validity(outputs.attentions) check_attentions_validity(outputs.attentions)

View File

@@ -148,6 +148,7 @@ def prepare_marian_inputs_dict(
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -163,6 +164,8 @@ def prepare_marian_inputs_dict(
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads)) head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
if cross_attn_head_mask is None:
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
@@ -170,6 +173,7 @@ def prepare_marian_inputs_dict(
"decoder_attention_mask": decoder_attention_mask, "decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }

View File

@@ -150,6 +150,7 @@ def prepare_mbart_inputs_dict(
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -165,13 +166,16 @@ def prepare_mbart_inputs_dict(
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads)) head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
if cross_attn_head_mask is None:
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask, "decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }

View File

@@ -146,6 +146,7 @@ def prepare_pegasus_inputs_dict(
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -161,6 +162,8 @@ def prepare_pegasus_inputs_dict(
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads)) head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
if cross_attn_head_mask is None:
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
@@ -168,6 +171,7 @@ def prepare_pegasus_inputs_dict(
"decoder_attention_mask": decoder_attention_mask, "decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }