TF BART models - Add cross_attentions to model output and fix cross-attention head masking (#10699)
* Add cross_attn_head_mask to BART * Fix cross_attentions in TFBart-like models * This commit enables returning of `cross_attentions` for TFBart-like models * It also fixes attention head masking in cross-attention module * Update TF model templates * Fix missing , in TF model templates * Fix typo: congig -> config
This commit is contained in:
@@ -116,6 +116,82 @@ class TFBaseModelOutputWithPast(ModelOutput):
|
|||||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TFBaseModelOutputWithCrossAttentions(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for model's outputs, with potential hidden states and attentions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the model.
|
||||||
|
hidden_states (:obj:`tuple(tf.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
|
sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
|
sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||||
|
weighted average in the cross-attention heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
last_hidden_state: tf.Tensor = None
|
||||||
|
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
cross_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TFBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the model.
|
||||||
|
|
||||||
|
If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
|
||||||
|
1, hidden_size)` is output.
|
||||||
|
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
|
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size,
|
||||||
|
num_heads, sequence_length, embed_size_per_head)`.
|
||||||
|
|
||||||
|
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
||||||
|
:obj:`past_key_values` input) to speed up sequential decoding.
|
||||||
|
hidden_states (:obj:`tuple(tf.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
|
sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
|
sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||||
|
weighted average in the cross-attention heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
last_hidden_state: tf.Tensor = None
|
||||||
|
past_key_values: Optional[List[tf.Tensor]] = None
|
||||||
|
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
cross_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TFSeq2SeqModelOutput(ModelOutput):
|
class TFSeq2SeqModelOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
@@ -145,6 +221,12 @@ class TFSeq2SeqModelOutput(ModelOutput):
|
|||||||
|
|
||||||
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
||||||
self-attention heads.
|
self-attention heads.
|
||||||
|
cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
|
sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||||
|
weighted average in the cross-attention heads.
|
||||||
encoder_last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
encoder_last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
||||||
encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
@@ -164,6 +246,7 @@ class TFSeq2SeqModelOutput(ModelOutput):
|
|||||||
past_key_values: Optional[List[tf.Tensor]] = None
|
past_key_values: Optional[List[tf.Tensor]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||||
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
cross_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
encoder_last_hidden_state: Optional[tf.Tensor] = None
|
encoder_last_hidden_state: Optional[tf.Tensor] = None
|
||||||
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||||
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
@@ -290,6 +373,12 @@ class TFSeq2SeqLMOutput(ModelOutput):
|
|||||||
|
|
||||||
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
||||||
self-attention heads.
|
self-attention heads.
|
||||||
|
cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
|
sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||||
|
weighted average in the cross-attention heads.
|
||||||
encoder_last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
encoder_last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
||||||
encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
@@ -310,6 +399,7 @@ class TFSeq2SeqLMOutput(ModelOutput):
|
|||||||
past_key_values: Optional[List[tf.Tensor]] = None
|
past_key_values: Optional[List[tf.Tensor]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||||
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
cross_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
encoder_last_hidden_state: Optional[tf.Tensor] = None
|
encoder_last_hidden_state: Optional[tf.Tensor] = None
|
||||||
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||||
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from ...file_utils import (
|
|||||||
)
|
)
|
||||||
from ...modeling_tf_outputs import (
|
from ...modeling_tf_outputs import (
|
||||||
TFBaseModelOutput,
|
TFBaseModelOutput,
|
||||||
TFBaseModelOutputWithPast,
|
TFBaseModelOutputWithPastAndCrossAttentions,
|
||||||
TFSeq2SeqLMOutput,
|
TFSeq2SeqLMOutput,
|
||||||
TFSeq2SeqModelOutput,
|
TFSeq2SeqModelOutput,
|
||||||
)
|
)
|
||||||
@@ -365,7 +365,7 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states: Optional[tf.Tensor] = None,
|
encoder_hidden_states: Optional[tf.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[tf.Tensor] = None,
|
encoder_attention_mask: Optional[tf.Tensor] = None,
|
||||||
layer_head_mask: Optional[tf.Tensor] = None,
|
layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
encoder_layer_head_mask: Optional[tf.Tensor] = None,
|
cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
||||||
training=False,
|
training=False,
|
||||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||||
@@ -379,8 +379,8 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
|
|||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||||
`(decoder_attention_heads,)`
|
`(decoder_attention_heads,)`
|
||||||
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
|
cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module.
|
||||||
`(encoder_attention_heads,)`
|
`(decoder_attention_heads,)`
|
||||||
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@@ -401,16 +401,17 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# Cross-Attention Block
|
# Cross-Attention Block
|
||||||
cross_attn_present_key_value = None
|
cross_attn_present_key_value = None
|
||||||
|
cross_attn_weights = None
|
||||||
if encoder_hidden_states is not None:
|
if encoder_hidden_states is not None:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
||||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
hidden_states, _, cross_attn_present_key_value = self.encoder_attn(
|
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
layer_head_mask=encoder_layer_head_mask,
|
layer_head_mask=cross_attn_layer_head_mask,
|
||||||
past_key_value=cross_attn_past_key_value,
|
past_key_value=cross_attn_past_key_value,
|
||||||
)
|
)
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
@@ -432,6 +433,7 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
|
|||||||
return (
|
return (
|
||||||
hidden_states,
|
hidden_states,
|
||||||
self_attn_weights,
|
self_attn_weights,
|
||||||
|
cross_attn_weights,
|
||||||
present_key_value,
|
present_key_value,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -572,7 +574,7 @@ BART_INPUTS_DOCSTRING = r"""
|
|||||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
@@ -580,6 +582,12 @@ BART_INPUTS_DOCSTRING = r"""
|
|||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the head is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
||||||
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||||
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
||||||
@@ -677,7 +685,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
|||||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||||
@@ -814,7 +822,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_head_mask=None,
|
cross_attn_head_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
@@ -856,14 +864,13 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
@@ -894,7 +901,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
encoder_head_mask=encoder_head_mask,
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -949,16 +956,18 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = () if inputs["output_hidden_states"] else None
|
all_hidden_states = () if inputs["output_hidden_states"] else None
|
||||||
all_self_attns = () if inputs["output_attentions"] else None
|
all_self_attns = () if inputs["output_attentions"] else None
|
||||||
|
all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
|
||||||
present_key_values = () if inputs["use_cache"] else None
|
present_key_values = () if inputs["use_cache"] else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
# The tf.debugging asserts are not compliant with XLA then they
|
||||||
# have to be disabled in other modes than eager.
|
# have to be disabled in other modes than eager.
|
||||||
if inputs["head_mask"] is not None and tf.executing_eagerly():
|
for attn_mask in ["head_mask", "cross_attn_head_mask"]:
|
||||||
|
if inputs[attn_mask] is not None and tf.executing_eagerly():
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(inputs["head_mask"])[0],
|
shape_list(inputs[attn_mask])[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
|
||||||
)
|
)
|
||||||
|
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
@@ -973,14 +982,14 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
|
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
|
||||||
|
|
||||||
hidden_states, layer_self_attn, present_key_value = decoder_layer(
|
hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
encoder_attention_mask=inputs["encoder_attention_mask"],
|
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||||
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||||
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
|
cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
|
||||||
if inputs["encoder_head_mask"] is not None
|
if inputs["cross_attn_head_mask"] is not None
|
||||||
else None,
|
else None,
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
)
|
)
|
||||||
@@ -991,23 +1000,30 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||||||
if inputs["output_attentions"]:
|
if inputs["output_attentions"]:
|
||||||
all_self_attns += (layer_self_attn,)
|
all_self_attns += (layer_self_attn,)
|
||||||
|
|
||||||
|
if inputs["encoder_hidden_states"] is not None:
|
||||||
|
all_cross_attns += (layer_cross_attn,)
|
||||||
|
|
||||||
if inputs["output_hidden_states"]:
|
if inputs["output_hidden_states"]:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if inputs["output_attentions"]:
|
if inputs["output_attentions"]:
|
||||||
all_self_attns = list(all_self_attns)
|
all_self_attns = list(all_self_attns)
|
||||||
|
|
||||||
|
if inputs["encoder_hidden_states"] is not None:
|
||||||
|
all_cross_attns = list(all_cross_attns)
|
||||||
|
|
||||||
if inputs["use_cache"]:
|
if inputs["use_cache"]:
|
||||||
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
|
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns
|
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
|
||||||
else:
|
else:
|
||||||
return TFBaseModelOutputWithPast(
|
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=present_key_values,
|
past_key_values=present_key_values,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attns,
|
attentions=all_self_attns,
|
||||||
|
cross_attentions=all_cross_attns,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1057,6 +1073,7 @@ class TFBartMainLayer(tf.keras.layers.Layer):
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1077,6 +1094,7 @@ class TFBartMainLayer(tf.keras.layers.Layer):
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
decoder_head_mask=decoder_head_mask,
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1131,7 +1149,7 @@ class TFBartMainLayer(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states=inputs["encoder_outputs"][0],
|
encoder_hidden_states=inputs["encoder_outputs"][0],
|
||||||
encoder_attention_mask=inputs["attention_mask"],
|
encoder_attention_mask=inputs["attention_mask"],
|
||||||
head_mask=inputs["decoder_head_mask"],
|
head_mask=inputs["decoder_head_mask"],
|
||||||
encoder_head_mask=inputs["head_mask"],
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["decoder_inputs_embeds"],
|
inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||||
use_cache=inputs["use_cache"],
|
use_cache=inputs["use_cache"],
|
||||||
@@ -1149,6 +1167,7 @@ class TFBartMainLayer(tf.keras.layers.Layer):
|
|||||||
past_key_values=decoder_outputs.past_key_values,
|
past_key_values=decoder_outputs.past_key_values,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs.attentions,
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
|
cross_attentions=decoder_outputs.cross_attentions,
|
||||||
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
||||||
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
||||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||||
@@ -1189,6 +1208,7 @@ class TFBartModel(TFBartPretrainedModel):
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1209,6 +1229,7 @@ class TFBartModel(TFBartPretrainedModel):
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
decoder_head_mask=decoder_head_mask,
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1228,6 +1249,7 @@ class TFBartModel(TFBartPretrainedModel):
|
|||||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
decoder_head_mask=inputs["decoder_head_mask"],
|
decoder_head_mask=inputs["decoder_head_mask"],
|
||||||
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
encoder_outputs=inputs["encoder_outputs"],
|
encoder_outputs=inputs["encoder_outputs"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
@@ -1245,6 +1267,7 @@ class TFBartModel(TFBartPretrainedModel):
|
|||||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
|
||||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||||
|
|
||||||
@@ -1253,6 +1276,7 @@ class TFBartModel(TFBartPretrainedModel):
|
|||||||
past_key_values=pkv,
|
past_key_values=pkv,
|
||||||
decoder_hidden_states=dec_hs,
|
decoder_hidden_states=dec_hs,
|
||||||
decoder_attentions=dec_attns,
|
decoder_attentions=dec_attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||||
encoder_hidden_states=enc_hs,
|
encoder_hidden_states=enc_hs,
|
||||||
encoder_attentions=enc_attns,
|
encoder_attentions=enc_attns,
|
||||||
@@ -1309,6 +1333,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1339,6 +1364,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
decoder_head_mask=decoder_head_mask,
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1372,6 +1398,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
|
|||||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
decoder_head_mask=inputs["decoder_head_mask"],
|
decoder_head_mask=inputs["decoder_head_mask"],
|
||||||
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||||
@@ -1394,6 +1421,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
|
|||||||
past_key_values=outputs.past_key_values, # index 1 of d outputs
|
past_key_values=outputs.past_key_values, # index 1 of d outputs
|
||||||
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
|
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
|
||||||
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
|
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
|
||||||
|
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
|
||||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
|
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
|
||||||
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
|
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
|
||||||
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
||||||
@@ -1403,6 +1431,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
|
|||||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
|
||||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||||
|
|
||||||
@@ -1411,6 +1440,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
|
|||||||
past_key_values=pkv,
|
past_key_values=pkv,
|
||||||
decoder_hidden_states=dec_hs,
|
decoder_hidden_states=dec_hs,
|
||||||
decoder_attentions=dec_attns,
|
decoder_attentions=dec_attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||||
encoder_hidden_states=enc_hs,
|
encoder_hidden_states=enc_hs,
|
||||||
encoder_attentions=enc_attns,
|
encoder_attentions=enc_attns,
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from ...file_utils import (
|
|||||||
)
|
)
|
||||||
from ...modeling_tf_outputs import (
|
from ...modeling_tf_outputs import (
|
||||||
TFBaseModelOutput,
|
TFBaseModelOutput,
|
||||||
TFBaseModelOutputWithPast,
|
TFBaseModelOutputWithPastAndCrossAttentions,
|
||||||
TFSeq2SeqLMOutput,
|
TFSeq2SeqLMOutput,
|
||||||
TFSeq2SeqModelOutput,
|
TFSeq2SeqModelOutput,
|
||||||
)
|
)
|
||||||
@@ -370,7 +370,7 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states: Optional[tf.Tensor] = None,
|
encoder_hidden_states: Optional[tf.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[tf.Tensor] = None,
|
encoder_attention_mask: Optional[tf.Tensor] = None,
|
||||||
layer_head_mask: Optional[tf.Tensor] = None,
|
layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
encoder_layer_head_mask: Optional[tf.Tensor] = None,
|
cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
||||||
training=False,
|
training=False,
|
||||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||||
@@ -384,8 +384,8 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
|
|||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||||
`(decoder_attention_heads,)`
|
`(decoder_attention_heads,)`
|
||||||
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
|
cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module.
|
||||||
`(encoder_attention_heads,)`
|
`(decoder_attention_heads,)`
|
||||||
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@@ -406,17 +406,18 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# Cross-Attention Block
|
# Cross-Attention Block
|
||||||
cross_attn_present_key_value = None
|
cross_attn_present_key_value = None
|
||||||
|
cross_attn_weights = None
|
||||||
if encoder_hidden_states is not None:
|
if encoder_hidden_states is not None:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
||||||
|
|
||||||
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
||||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
hidden_states, _, cross_attn_present_key_value = self.encoder_attn(
|
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
layer_head_mask=encoder_layer_head_mask,
|
layer_head_mask=cross_attn_layer_head_mask,
|
||||||
past_key_value=cross_attn_past_key_value,
|
past_key_value=cross_attn_past_key_value,
|
||||||
)
|
)
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
@@ -437,6 +438,7 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
|
|||||||
return (
|
return (
|
||||||
hidden_states,
|
hidden_states,
|
||||||
self_attn_weights,
|
self_attn_weights,
|
||||||
|
cross_attn_weights,
|
||||||
present_key_value,
|
present_key_value,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -569,7 +571,7 @@ BLENDERBOT_INPUTS_DOCSTRING = r"""
|
|||||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
@@ -577,6 +579,12 @@ BLENDERBOT_INPUTS_DOCSTRING = r"""
|
|||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the head is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
||||||
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||||
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
||||||
@@ -674,7 +682,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
|
|||||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||||
@@ -818,7 +826,7 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_head_mask=None,
|
cross_attn_head_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
@@ -860,14 +868,13 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
|
|||||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
@@ -904,7 +911,7 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
encoder_head_mask=encoder_head_mask,
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -957,18 +964,20 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
|
|||||||
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
||||||
|
|
||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = ()
|
all_hidden_states = () if inputs["output_hidden_states"] else None
|
||||||
all_self_attns = ()
|
all_self_attns = () if inputs["output_attentions"] else None
|
||||||
present_key_values = ()
|
all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
|
||||||
|
present_key_values = () if inputs["use_cache"] else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
# The tf.debugging asserts are not compliant with XLA then they
|
||||||
# have to be disabled in other modes than eager.
|
# have to be disabled in other modes than eager.
|
||||||
if inputs["head_mask"] is not None and tf.executing_eagerly():
|
for attn_mask in ["head_mask", "cross_attn_head_mask"]:
|
||||||
|
if inputs[attn_mask] is not None and tf.executing_eagerly():
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(inputs["head_mask"])[0],
|
shape_list(inputs[attn_mask])[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
|
||||||
)
|
)
|
||||||
|
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
@@ -982,14 +991,14 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
|
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
|
||||||
|
|
||||||
hidden_states, layer_self_attn, present_key_value = decoder_layer(
|
hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
encoder_attention_mask=inputs["encoder_attention_mask"],
|
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||||
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||||
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
|
cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
|
||||||
if inputs["encoder_head_mask"] is not None
|
if inputs["cross_attn_head_mask"] is not None
|
||||||
else None,
|
else None,
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
)
|
)
|
||||||
@@ -1000,25 +1009,32 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
|
|||||||
if inputs["output_attentions"]:
|
if inputs["output_attentions"]:
|
||||||
all_self_attns += (layer_self_attn,)
|
all_self_attns += (layer_self_attn,)
|
||||||
|
|
||||||
|
if inputs["encoder_hidden_states"] is not None:
|
||||||
|
all_cross_attns += (layer_cross_attn,)
|
||||||
|
|
||||||
hidden_states = self.layer_norm(hidden_states)
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
|
|
||||||
if inputs["output_hidden_states"]:
|
if inputs["output_hidden_states"]:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
else:
|
|
||||||
all_hidden_states = None
|
|
||||||
|
|
||||||
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
|
if inputs["output_attentions"]:
|
||||||
|
all_self_attns = list(all_self_attns)
|
||||||
|
|
||||||
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
|
if inputs["encoder_hidden_states"] is not None:
|
||||||
|
all_cross_attns = list(all_cross_attns)
|
||||||
|
|
||||||
|
if inputs["use_cache"]:
|
||||||
|
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns
|
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
|
||||||
else:
|
else:
|
||||||
return TFBaseModelOutputWithPast(
|
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=present_key_values,
|
past_key_values=present_key_values,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attns,
|
attentions=all_self_attns,
|
||||||
|
cross_attentions=all_cross_attns,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1065,6 +1081,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1085,6 +1102,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
decoder_head_mask=decoder_head_mask,
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1131,7 +1149,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states=inputs["encoder_outputs"][0],
|
encoder_hidden_states=inputs["encoder_outputs"][0],
|
||||||
encoder_attention_mask=inputs["attention_mask"],
|
encoder_attention_mask=inputs["attention_mask"],
|
||||||
head_mask=inputs["decoder_head_mask"],
|
head_mask=inputs["decoder_head_mask"],
|
||||||
encoder_head_mask=inputs["head_mask"],
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["decoder_inputs_embeds"],
|
inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||||
use_cache=inputs["use_cache"],
|
use_cache=inputs["use_cache"],
|
||||||
@@ -1149,6 +1167,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
|
|||||||
past_key_values=decoder_outputs.past_key_values,
|
past_key_values=decoder_outputs.past_key_values,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs.attentions,
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
|
cross_attentions=decoder_outputs.cross_attentions,
|
||||||
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
||||||
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
||||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||||
@@ -1199,6 +1218,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1219,6 +1239,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
decoder_head_mask=decoder_head_mask,
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1238,6 +1259,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
|
|||||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
decoder_head_mask=inputs["decoder_head_mask"],
|
decoder_head_mask=inputs["decoder_head_mask"],
|
||||||
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
encoder_outputs=inputs["encoder_outputs"],
|
encoder_outputs=inputs["encoder_outputs"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
@@ -1256,6 +1278,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
|
|||||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
|
||||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||||
|
|
||||||
@@ -1264,6 +1287,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
|
|||||||
past_key_values=pkv,
|
past_key_values=pkv,
|
||||||
decoder_hidden_states=dec_hs,
|
decoder_hidden_states=dec_hs,
|
||||||
decoder_attentions=dec_attns,
|
decoder_attentions=dec_attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||||
encoder_hidden_states=enc_hs,
|
encoder_hidden_states=enc_hs,
|
||||||
encoder_attentions=enc_attns,
|
encoder_attentions=enc_attns,
|
||||||
@@ -1331,6 +1355,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1361,6 +1386,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
decoder_head_mask=decoder_head_mask,
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1394,6 +1420,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
|
|||||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
decoder_head_mask=inputs["decoder_head_mask"],
|
decoder_head_mask=inputs["decoder_head_mask"],
|
||||||
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||||
@@ -1416,6 +1443,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
|
|||||||
past_key_values=outputs.past_key_values, # index 1 of d outputs
|
past_key_values=outputs.past_key_values, # index 1 of d outputs
|
||||||
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
|
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
|
||||||
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
|
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
|
||||||
|
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
|
||||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
|
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
|
||||||
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
|
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
|
||||||
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
||||||
@@ -1426,6 +1454,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
|
|||||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
|
||||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||||
|
|
||||||
@@ -1434,6 +1463,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
|
|||||||
past_key_values=pkv,
|
past_key_values=pkv,
|
||||||
decoder_hidden_states=dec_hs,
|
decoder_hidden_states=dec_hs,
|
||||||
decoder_attentions=dec_attns,
|
decoder_attentions=dec_attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||||
encoder_hidden_states=enc_hs,
|
encoder_hidden_states=enc_hs,
|
||||||
encoder_attentions=enc_attns,
|
encoder_attentions=enc_attns,
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from ...file_utils import (
|
|||||||
)
|
)
|
||||||
from ...modeling_tf_outputs import (
|
from ...modeling_tf_outputs import (
|
||||||
TFBaseModelOutput,
|
TFBaseModelOutput,
|
||||||
TFBaseModelOutputWithPast,
|
TFBaseModelOutputWithPastAndCrossAttentions,
|
||||||
TFSeq2SeqLMOutput,
|
TFSeq2SeqLMOutput,
|
||||||
TFSeq2SeqModelOutput,
|
TFSeq2SeqModelOutput,
|
||||||
)
|
)
|
||||||
@@ -369,7 +369,7 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states: Optional[tf.Tensor] = None,
|
encoder_hidden_states: Optional[tf.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[tf.Tensor] = None,
|
encoder_attention_mask: Optional[tf.Tensor] = None,
|
||||||
layer_head_mask: Optional[tf.Tensor] = None,
|
layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
encoder_layer_head_mask: Optional[tf.Tensor] = None,
|
cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
||||||
training=False,
|
training=False,
|
||||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||||
@@ -383,8 +383,8 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
|
|||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||||
`(decoder_attention_heads,)`
|
`(decoder_attention_heads,)`
|
||||||
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
|
cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module.
|
||||||
`(encoder_attention_heads,)`
|
`(decoder_attention_heads,)`
|
||||||
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@@ -405,16 +405,17 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# Cross-Attention Block
|
# Cross-Attention Block
|
||||||
cross_attn_present_key_value = None
|
cross_attn_present_key_value = None
|
||||||
|
cross_attn_weights = None
|
||||||
if encoder_hidden_states is not None:
|
if encoder_hidden_states is not None:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
||||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
hidden_states, _, cross_attn_present_key_value = self.encoder_attn(
|
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
layer_head_mask=encoder_layer_head_mask,
|
layer_head_mask=cross_attn_layer_head_mask,
|
||||||
past_key_value=cross_attn_past_key_value,
|
past_key_value=cross_attn_past_key_value,
|
||||||
)
|
)
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
@@ -436,6 +437,7 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
|
|||||||
return (
|
return (
|
||||||
hidden_states,
|
hidden_states,
|
||||||
self_attn_weights,
|
self_attn_weights,
|
||||||
|
cross_attn_weights,
|
||||||
present_key_value,
|
present_key_value,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -574,7 +576,7 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r"""
|
|||||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
@@ -582,6 +584,12 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r"""
|
|||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the head is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
||||||
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||||
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
||||||
@@ -679,7 +687,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
|
|||||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||||
@@ -823,7 +831,7 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_head_mask=None,
|
cross_attn_head_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
@@ -865,14 +873,13 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
|||||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
@@ -909,7 +916,7 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
encoder_head_mask=encoder_head_mask,
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -960,18 +967,20 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
|||||||
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
||||||
|
|
||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = ()
|
all_hidden_states = () if inputs["output_hidden_states"] else None
|
||||||
all_self_attns = ()
|
all_self_attns = () if inputs["output_attentions"] else None
|
||||||
present_key_values = ()
|
all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
|
||||||
|
present_key_values = () if inputs["use_cache"] else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
# The tf.debugging asserts are not compliant with XLA then they
|
||||||
# have to be disabled in other modes than eager.
|
# have to be disabled in other modes than eager.
|
||||||
if inputs["head_mask"] is not None and tf.executing_eagerly():
|
for attn_mask in ["head_mask", "cross_attn_head_mask"]:
|
||||||
|
if inputs[attn_mask] is not None and tf.executing_eagerly():
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(inputs["head_mask"])[0],
|
shape_list(inputs[attn_mask])[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
|
||||||
)
|
)
|
||||||
|
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
@@ -985,14 +994,14 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
|
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
|
||||||
|
|
||||||
hidden_states, layer_self_attn, present_key_value = decoder_layer(
|
hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
encoder_attention_mask=inputs["encoder_attention_mask"],
|
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||||
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||||
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
|
cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
|
||||||
if inputs["encoder_head_mask"] is not None
|
if inputs["cross_attn_head_mask"] is not None
|
||||||
else None,
|
else None,
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
)
|
)
|
||||||
@@ -1003,23 +1012,30 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
|||||||
if inputs["output_attentions"]:
|
if inputs["output_attentions"]:
|
||||||
all_self_attns += (layer_self_attn,)
|
all_self_attns += (layer_self_attn,)
|
||||||
|
|
||||||
|
if inputs["encoder_hidden_states"] is not None:
|
||||||
|
all_cross_attns += (layer_cross_attn,)
|
||||||
|
|
||||||
if inputs["output_hidden_states"]:
|
if inputs["output_hidden_states"]:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
else:
|
|
||||||
all_hidden_states = None
|
|
||||||
|
|
||||||
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
|
if inputs["output_attentions"]:
|
||||||
|
all_self_attns = list(all_self_attns)
|
||||||
|
|
||||||
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
|
if inputs["encoder_hidden_states"] is not None:
|
||||||
|
all_cross_attns = list(all_cross_attns)
|
||||||
|
|
||||||
|
if inputs["use_cache"]:
|
||||||
|
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns
|
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
|
||||||
else:
|
else:
|
||||||
return TFBaseModelOutputWithPast(
|
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=present_key_values,
|
past_key_values=present_key_values,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attns,
|
attentions=all_self_attns,
|
||||||
|
cross_attentions=all_cross_attns,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1066,6 +1082,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1086,6 +1103,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
decoder_head_mask=decoder_head_mask,
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1132,7 +1150,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states=inputs["encoder_outputs"][0],
|
encoder_hidden_states=inputs["encoder_outputs"][0],
|
||||||
encoder_attention_mask=inputs["attention_mask"],
|
encoder_attention_mask=inputs["attention_mask"],
|
||||||
head_mask=inputs["decoder_head_mask"],
|
head_mask=inputs["decoder_head_mask"],
|
||||||
encoder_head_mask=inputs["head_mask"],
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["decoder_inputs_embeds"],
|
inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||||
use_cache=inputs["use_cache"],
|
use_cache=inputs["use_cache"],
|
||||||
@@ -1150,6 +1168,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
|
|||||||
past_key_values=decoder_outputs.past_key_values,
|
past_key_values=decoder_outputs.past_key_values,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs.attentions,
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
|
cross_attentions=decoder_outputs.cross_attentions,
|
||||||
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
||||||
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
||||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||||
@@ -1187,6 +1206,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1207,6 +1227,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
decoder_head_mask=decoder_head_mask,
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1226,6 +1247,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
|
|||||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
decoder_head_mask=inputs["decoder_head_mask"],
|
decoder_head_mask=inputs["decoder_head_mask"],
|
||||||
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
encoder_outputs=inputs["encoder_outputs"],
|
encoder_outputs=inputs["encoder_outputs"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
@@ -1244,6 +1266,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
|
|||||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
|
||||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||||
|
|
||||||
@@ -1252,6 +1275,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
|
|||||||
past_key_values=pkv,
|
past_key_values=pkv,
|
||||||
decoder_hidden_states=dec_hs,
|
decoder_hidden_states=dec_hs,
|
||||||
decoder_attentions=dec_attns,
|
decoder_attentions=dec_attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||||
encoder_hidden_states=enc_hs,
|
encoder_hidden_states=enc_hs,
|
||||||
encoder_attentions=enc_attns,
|
encoder_attentions=enc_attns,
|
||||||
@@ -1306,6 +1330,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1336,6 +1361,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
decoder_head_mask=decoder_head_mask,
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1369,6 +1395,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
|||||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
decoder_head_mask=inputs["decoder_head_mask"],
|
decoder_head_mask=inputs["decoder_head_mask"],
|
||||||
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||||
@@ -1391,6 +1418,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
|||||||
past_key_values=outputs.past_key_values, # index 1 of d outputs
|
past_key_values=outputs.past_key_values, # index 1 of d outputs
|
||||||
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
|
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
|
||||||
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
|
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
|
||||||
|
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
|
||||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
|
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
|
||||||
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
|
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
|
||||||
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
||||||
@@ -1401,6 +1429,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
|||||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
|
||||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||||
|
|
||||||
@@ -1409,6 +1438,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
|||||||
past_key_values=pkv,
|
past_key_values=pkv,
|
||||||
decoder_hidden_states=dec_hs,
|
decoder_hidden_states=dec_hs,
|
||||||
decoder_attentions=dec_attns,
|
decoder_attentions=dec_attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||||
encoder_hidden_states=enc_hs,
|
encoder_hidden_states=enc_hs,
|
||||||
encoder_attentions=enc_attns,
|
encoder_attentions=enc_attns,
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from ...file_utils import (
|
|||||||
)
|
)
|
||||||
from ...modeling_tf_outputs import (
|
from ...modeling_tf_outputs import (
|
||||||
TFBaseModelOutput,
|
TFBaseModelOutput,
|
||||||
TFBaseModelOutputWithPast,
|
TFBaseModelOutputWithPastAndCrossAttentions,
|
||||||
TFSeq2SeqLMOutput,
|
TFSeq2SeqLMOutput,
|
||||||
TFSeq2SeqModelOutput,
|
TFSeq2SeqModelOutput,
|
||||||
)
|
)
|
||||||
@@ -408,7 +408,7 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states: Optional[tf.Tensor] = None,
|
encoder_hidden_states: Optional[tf.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[tf.Tensor] = None,
|
encoder_attention_mask: Optional[tf.Tensor] = None,
|
||||||
layer_head_mask: Optional[tf.Tensor] = None,
|
layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
encoder_layer_head_mask: Optional[tf.Tensor] = None,
|
cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
||||||
training=False,
|
training=False,
|
||||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||||
@@ -422,8 +422,8 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
|
|||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||||
`(decoder_attention_heads,)`
|
`(decoder_attention_heads,)`
|
||||||
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
|
cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module.
|
||||||
`(encoder_attention_heads,)`
|
`(decoder_attention_heads,)`
|
||||||
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@@ -444,16 +444,17 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# Cross-Attention Block
|
# Cross-Attention Block
|
||||||
cross_attn_present_key_value = None
|
cross_attn_present_key_value = None
|
||||||
|
cross_attn_weights = None
|
||||||
if encoder_hidden_states is not None:
|
if encoder_hidden_states is not None:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
||||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
hidden_states, _, cross_attn_present_key_value = self.encoder_attn(
|
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
layer_head_mask=encoder_layer_head_mask,
|
layer_head_mask=cross_attn_layer_head_mask,
|
||||||
past_key_value=cross_attn_past_key_value,
|
past_key_value=cross_attn_past_key_value,
|
||||||
)
|
)
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
@@ -475,6 +476,7 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
|
|||||||
return (
|
return (
|
||||||
hidden_states,
|
hidden_states,
|
||||||
self_attn_weights,
|
self_attn_weights,
|
||||||
|
cross_attn_weights,
|
||||||
present_key_value,
|
present_key_value,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -603,7 +605,7 @@ MARIAN_INPUTS_DOCSTRING = r"""
|
|||||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
@@ -611,6 +613,12 @@ MARIAN_INPUTS_DOCSTRING = r"""
|
|||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the head is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
||||||
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||||
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
||||||
@@ -707,7 +715,7 @@ class TFMarianEncoder(tf.keras.layers.Layer):
|
|||||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||||
@@ -848,7 +856,7 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_head_mask=None,
|
cross_attn_head_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
@@ -890,14 +898,13 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
|||||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
@@ -934,7 +941,7 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
encoder_head_mask=encoder_head_mask,
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -986,18 +993,20 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
|||||||
hidden_states = self.dropout(hidden_states + positions, training=inputs["training"])
|
hidden_states = self.dropout(hidden_states + positions, training=inputs["training"])
|
||||||
|
|
||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = ()
|
all_hidden_states = () if inputs["output_hidden_states"] else None
|
||||||
all_self_attns = ()
|
all_self_attns = () if inputs["output_attentions"] else None
|
||||||
present_key_values = ()
|
all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
|
||||||
|
present_key_values = () if inputs["use_cache"] else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
# The tf.debugging asserts are not compliant with XLA then they
|
||||||
# have to be disabled in other modes than eager.
|
# have to be disabled in other modes than eager.
|
||||||
if inputs["head_mask"] is not None and tf.executing_eagerly():
|
for attn_mask in ["head_mask", "cross_attn_head_mask"]:
|
||||||
|
if inputs[attn_mask] is not None and tf.executing_eagerly():
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(inputs["head_mask"])[0],
|
shape_list(inputs[attn_mask])[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
|
||||||
)
|
)
|
||||||
|
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
@@ -1011,14 +1020,14 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
|
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
|
||||||
|
|
||||||
hidden_states, layer_self_attn, present_key_value = decoder_layer(
|
hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
encoder_attention_mask=inputs["encoder_attention_mask"],
|
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||||
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||||
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
|
cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
|
||||||
if inputs["encoder_head_mask"] is not None
|
if inputs["cross_attn_head_mask"] is not None
|
||||||
else None,
|
else None,
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
)
|
)
|
||||||
@@ -1029,23 +1038,30 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
|||||||
if inputs["output_attentions"]:
|
if inputs["output_attentions"]:
|
||||||
all_self_attns += (layer_self_attn,)
|
all_self_attns += (layer_self_attn,)
|
||||||
|
|
||||||
|
if inputs["encoder_hidden_states"] is not None:
|
||||||
|
all_cross_attns += (layer_cross_attn,)
|
||||||
|
|
||||||
if inputs["output_hidden_states"]:
|
if inputs["output_hidden_states"]:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
else:
|
|
||||||
all_hidden_states = None
|
|
||||||
|
|
||||||
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
|
if inputs["output_attentions"]:
|
||||||
|
all_self_attns = list(all_self_attns)
|
||||||
|
|
||||||
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
|
if inputs["encoder_hidden_states"] is not None:
|
||||||
|
all_cross_attns = list(all_cross_attns)
|
||||||
|
|
||||||
|
if inputs["use_cache"]:
|
||||||
|
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns
|
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
|
||||||
else:
|
else:
|
||||||
return TFBaseModelOutputWithPast(
|
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=present_key_values,
|
past_key_values=present_key_values,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attns,
|
attentions=all_self_attns,
|
||||||
|
cross_attentions=all_cross_attns,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1092,6 +1108,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1112,6 +1129,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
decoder_head_mask=decoder_head_mask,
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1161,7 +1179,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states=inputs["encoder_outputs"][0],
|
encoder_hidden_states=inputs["encoder_outputs"][0],
|
||||||
encoder_attention_mask=inputs["attention_mask"],
|
encoder_attention_mask=inputs["attention_mask"],
|
||||||
head_mask=inputs["decoder_head_mask"],
|
head_mask=inputs["decoder_head_mask"],
|
||||||
encoder_head_mask=inputs["head_mask"],
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["decoder_inputs_embeds"],
|
inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||||
use_cache=inputs["use_cache"],
|
use_cache=inputs["use_cache"],
|
||||||
@@ -1179,6 +1197,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
|
|||||||
past_key_values=decoder_outputs.past_key_values,
|
past_key_values=decoder_outputs.past_key_values,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs.attentions,
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
|
cross_attentions=decoder_outputs.cross_attentions,
|
||||||
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
||||||
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
||||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||||
@@ -1216,6 +1235,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1235,6 +1255,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
decoder_head_mask=decoder_head_mask,
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
@@ -1255,6 +1276,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
|
|||||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
decoder_head_mask=inputs["decoder_head_mask"],
|
decoder_head_mask=inputs["decoder_head_mask"],
|
||||||
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
encoder_outputs=inputs["encoder_outputs"],
|
encoder_outputs=inputs["encoder_outputs"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
@@ -1273,6 +1295,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
|
|||||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
|
||||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||||
|
|
||||||
@@ -1281,6 +1304,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
|
|||||||
past_key_values=pkv,
|
past_key_values=pkv,
|
||||||
decoder_hidden_states=dec_hs,
|
decoder_hidden_states=dec_hs,
|
||||||
decoder_attentions=dec_attns,
|
decoder_attentions=dec_attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||||
encoder_hidden_states=enc_hs,
|
encoder_hidden_states=enc_hs,
|
||||||
encoder_attentions=enc_attns,
|
encoder_attentions=enc_attns,
|
||||||
@@ -1335,6 +1359,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1365,6 +1390,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
decoder_head_mask=decoder_head_mask,
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1398,6 +1424,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
decoder_head_mask=inputs["decoder_head_mask"],
|
decoder_head_mask=inputs["decoder_head_mask"],
|
||||||
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||||
@@ -1420,6 +1447,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
past_key_values=outputs.past_key_values, # index 1 of d outputs
|
past_key_values=outputs.past_key_values, # index 1 of d outputs
|
||||||
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
|
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
|
||||||
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
|
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
|
||||||
|
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
|
||||||
encoder_last_hidden_state=outputs.last_hidden_state, # index 0 of encoder outputs
|
encoder_last_hidden_state=outputs.last_hidden_state, # index 0 of encoder outputs
|
||||||
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
|
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
|
||||||
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
||||||
@@ -1430,6 +1458,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
|
||||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||||
|
|
||||||
@@ -1438,6 +1467,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
past_key_values=pkv,
|
past_key_values=pkv,
|
||||||
decoder_hidden_states=dec_hs,
|
decoder_hidden_states=dec_hs,
|
||||||
decoder_attentions=dec_attns,
|
decoder_attentions=dec_attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||||
encoder_hidden_states=enc_hs,
|
encoder_hidden_states=enc_hs,
|
||||||
encoder_attentions=enc_attns,
|
encoder_attentions=enc_attns,
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from ...file_utils import (
|
|||||||
)
|
)
|
||||||
from ...modeling_tf_outputs import (
|
from ...modeling_tf_outputs import (
|
||||||
TFBaseModelOutput,
|
TFBaseModelOutput,
|
||||||
TFBaseModelOutputWithPast,
|
TFBaseModelOutputWithPastAndCrossAttentions,
|
||||||
TFSeq2SeqLMOutput,
|
TFSeq2SeqLMOutput,
|
||||||
TFSeq2SeqModelOutput,
|
TFSeq2SeqModelOutput,
|
||||||
)
|
)
|
||||||
@@ -368,7 +368,7 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states: Optional[tf.Tensor] = None,
|
encoder_hidden_states: Optional[tf.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[tf.Tensor] = None,
|
encoder_attention_mask: Optional[tf.Tensor] = None,
|
||||||
layer_head_mask: Optional[tf.Tensor] = None,
|
layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
encoder_layer_head_mask: Optional[tf.Tensor] = None,
|
cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
||||||
training=False,
|
training=False,
|
||||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||||
@@ -382,8 +382,8 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
|
|||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||||
`(decoder_attention_heads,)`
|
`(decoder_attention_heads,)`
|
||||||
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
|
cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module.
|
||||||
`(encoder_attention_heads,)`
|
`(decoder_attention_heads,)`
|
||||||
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@@ -404,17 +404,18 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# Cross-Attention Block
|
# Cross-Attention Block
|
||||||
cross_attn_present_key_value = None
|
cross_attn_present_key_value = None
|
||||||
|
cross_attn_weights = None
|
||||||
if encoder_hidden_states is not None:
|
if encoder_hidden_states is not None:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
||||||
|
|
||||||
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
||||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
hidden_states, _, cross_attn_present_key_value = self.encoder_attn(
|
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
layer_head_mask=encoder_layer_head_mask,
|
layer_head_mask=cross_attn_layer_head_mask,
|
||||||
past_key_value=cross_attn_past_key_value,
|
past_key_value=cross_attn_past_key_value,
|
||||||
)
|
)
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
@@ -435,6 +436,7 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
|
|||||||
return (
|
return (
|
||||||
hidden_states,
|
hidden_states,
|
||||||
self_attn_weights,
|
self_attn_weights,
|
||||||
|
cross_attn_weights,
|
||||||
present_key_value,
|
present_key_value,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -547,7 +549,7 @@ MBART_INPUTS_DOCSTRING = r"""
|
|||||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
@@ -555,6 +557,12 @@ MBART_INPUTS_DOCSTRING = r"""
|
|||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the head is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
||||||
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||||
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
||||||
@@ -828,7 +836,7 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_head_mask=None,
|
cross_attn_head_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
@@ -870,14 +878,13 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
|||||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
@@ -914,7 +921,7 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
encoder_head_mask=encoder_head_mask,
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -967,18 +974,20 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
|||||||
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
||||||
|
|
||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = ()
|
all_hidden_states = () if inputs["output_hidden_states"] else None
|
||||||
all_self_attns = ()
|
all_self_attns = () if inputs["output_attentions"] else None
|
||||||
present_key_values = ()
|
all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
|
||||||
|
present_key_values = () if inputs["use_cache"] else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
# The tf.debugging asserts are not compliant with XLA then they
|
||||||
# have to be disabled in other modes than eager.
|
# have to be disabled in other modes than eager.
|
||||||
if inputs["head_mask"] is not None and tf.executing_eagerly():
|
for attn_mask in ["head_mask", "cross_attn_head_mask"]:
|
||||||
|
if inputs[attn_mask] is not None and tf.executing_eagerly():
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(inputs["head_mask"])[0],
|
shape_list(inputs[attn_mask])[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
|
||||||
)
|
)
|
||||||
|
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
@@ -992,14 +1001,14 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
|
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
|
||||||
|
|
||||||
hidden_states, layer_self_attn, present_key_value = decoder_layer(
|
hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
encoder_attention_mask=inputs["encoder_attention_mask"],
|
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||||
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||||
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
|
cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
|
||||||
if inputs["encoder_head_mask"] is not None
|
if inputs["cross_attn_head_mask"] is not None
|
||||||
else None,
|
else None,
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
)
|
)
|
||||||
@@ -1010,25 +1019,32 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
|||||||
if inputs["output_attentions"]:
|
if inputs["output_attentions"]:
|
||||||
all_self_attns += (layer_self_attn,)
|
all_self_attns += (layer_self_attn,)
|
||||||
|
|
||||||
|
if inputs["encoder_hidden_states"] is not None:
|
||||||
|
all_cross_attns += (layer_cross_attn,)
|
||||||
|
|
||||||
hidden_states = self.layer_norm(hidden_states)
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
|
|
||||||
if inputs["output_hidden_states"]:
|
if inputs["output_hidden_states"]:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
else:
|
|
||||||
all_hidden_states = None
|
|
||||||
|
|
||||||
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
|
if inputs["output_attentions"]:
|
||||||
|
all_self_attns = list(all_self_attns)
|
||||||
|
|
||||||
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
|
if inputs["encoder_hidden_states"] is not None:
|
||||||
|
all_cross_attns = list(all_cross_attns)
|
||||||
|
|
||||||
|
if inputs["use_cache"]:
|
||||||
|
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns
|
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
|
||||||
else:
|
else:
|
||||||
return TFBaseModelOutputWithPast(
|
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=present_key_values,
|
past_key_values=present_key_values,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attns,
|
attentions=all_self_attns,
|
||||||
|
cross_attentions=all_cross_attns,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1075,6 +1091,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1095,6 +1112,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
decoder_head_mask=decoder_head_mask,
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1147,7 +1165,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states=inputs["encoder_outputs"][0],
|
encoder_hidden_states=inputs["encoder_outputs"][0],
|
||||||
encoder_attention_mask=inputs["attention_mask"],
|
encoder_attention_mask=inputs["attention_mask"],
|
||||||
head_mask=inputs["decoder_head_mask"],
|
head_mask=inputs["decoder_head_mask"],
|
||||||
encoder_head_mask=inputs["head_mask"],
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["decoder_inputs_embeds"],
|
inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||||
use_cache=inputs["use_cache"],
|
use_cache=inputs["use_cache"],
|
||||||
@@ -1165,6 +1183,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
|
|||||||
past_key_values=decoder_outputs.past_key_values,
|
past_key_values=decoder_outputs.past_key_values,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs.attentions,
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
|
cross_attentions=decoder_outputs.cross_attentions,
|
||||||
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
||||||
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
||||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||||
@@ -1202,6 +1221,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1222,6 +1242,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
decoder_head_mask=decoder_head_mask,
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1241,6 +1262,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
|
|||||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
decoder_head_mask=inputs["decoder_head_mask"],
|
decoder_head_mask=inputs["decoder_head_mask"],
|
||||||
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
encoder_outputs=inputs["encoder_outputs"],
|
encoder_outputs=inputs["encoder_outputs"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
@@ -1259,6 +1281,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
|
|||||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
|
||||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||||
|
|
||||||
@@ -1267,6 +1290,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
|
|||||||
past_key_values=pkv,
|
past_key_values=pkv,
|
||||||
decoder_hidden_states=dec_hs,
|
decoder_hidden_states=dec_hs,
|
||||||
decoder_attentions=dec_attns,
|
decoder_attentions=dec_attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||||
encoder_hidden_states=enc_hs,
|
encoder_hidden_states=enc_hs,
|
||||||
encoder_attentions=enc_attns,
|
encoder_attentions=enc_attns,
|
||||||
@@ -1321,6 +1345,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1351,6 +1376,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
decoder_head_mask=decoder_head_mask,
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1382,6 +1408,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
|
|||||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
decoder_head_mask=inputs["decoder_head_mask"],
|
decoder_head_mask=inputs["decoder_head_mask"],
|
||||||
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||||
@@ -1404,6 +1431,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
|
|||||||
past_key_values=outputs.past_key_values, # index 1 of d outputs
|
past_key_values=outputs.past_key_values, # index 1 of d outputs
|
||||||
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
|
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
|
||||||
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
|
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
|
||||||
|
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
|
||||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
|
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
|
||||||
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
|
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
|
||||||
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
||||||
@@ -1414,6 +1442,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
|
|||||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
|
||||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||||
|
|
||||||
@@ -1422,6 +1451,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
|
|||||||
past_key_values=pkv,
|
past_key_values=pkv,
|
||||||
decoder_hidden_states=dec_hs,
|
decoder_hidden_states=dec_hs,
|
||||||
decoder_attentions=dec_attns,
|
decoder_attentions=dec_attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||||
encoder_hidden_states=enc_hs,
|
encoder_hidden_states=enc_hs,
|
||||||
encoder_attentions=enc_attns,
|
encoder_attentions=enc_attns,
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from ...file_utils import (
|
|||||||
)
|
)
|
||||||
from ...modeling_tf_outputs import (
|
from ...modeling_tf_outputs import (
|
||||||
TFBaseModelOutput,
|
TFBaseModelOutput,
|
||||||
TFBaseModelOutputWithPast,
|
TFBaseModelOutputWithPastAndCrossAttentions,
|
||||||
TFSeq2SeqLMOutput,
|
TFSeq2SeqLMOutput,
|
||||||
TFSeq2SeqModelOutput,
|
TFSeq2SeqModelOutput,
|
||||||
)
|
)
|
||||||
@@ -409,7 +409,7 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states: Optional[tf.Tensor] = None,
|
encoder_hidden_states: Optional[tf.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[tf.Tensor] = None,
|
encoder_attention_mask: Optional[tf.Tensor] = None,
|
||||||
layer_head_mask: Optional[tf.Tensor] = None,
|
layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
encoder_layer_head_mask: Optional[tf.Tensor] = None,
|
cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
||||||
training=False,
|
training=False,
|
||||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||||
@@ -423,8 +423,8 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
|
|||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||||
`(decoder_attention_heads,)`
|
`(decoder_attention_heads,)`
|
||||||
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
|
cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module.
|
||||||
`(encoder_attention_heads,)`
|
`(decoder_attention_heads,)`
|
||||||
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@@ -445,17 +445,18 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# Cross-Attention Block
|
# Cross-Attention Block
|
||||||
cross_attn_present_key_value = None
|
cross_attn_present_key_value = None
|
||||||
|
cross_attn_weights = None
|
||||||
if encoder_hidden_states is not None:
|
if encoder_hidden_states is not None:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
||||||
|
|
||||||
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
||||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
hidden_states, _, cross_attn_present_key_value = self.encoder_attn(
|
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
layer_head_mask=encoder_layer_head_mask,
|
layer_head_mask=cross_attn_layer_head_mask,
|
||||||
past_key_value=cross_attn_past_key_value,
|
past_key_value=cross_attn_past_key_value,
|
||||||
)
|
)
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
@@ -476,6 +477,7 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
|
|||||||
return (
|
return (
|
||||||
hidden_states,
|
hidden_states,
|
||||||
self_attn_weights,
|
self_attn_weights,
|
||||||
|
cross_attn_weights,
|
||||||
present_key_value,
|
present_key_value,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -603,7 +605,7 @@ PEGASUS_INPUTS_DOCSTRING = r"""
|
|||||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
@@ -611,6 +613,12 @@ PEGASUS_INPUTS_DOCSTRING = r"""
|
|||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the head is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
||||||
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||||
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
||||||
@@ -855,7 +863,7 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_head_mask=None,
|
cross_attn_head_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
@@ -897,14 +905,13 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
|
|||||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the heas is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
@@ -941,7 +948,7 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
encoder_head_mask=encoder_head_mask,
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -993,18 +1000,20 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
|
|||||||
hidden_states = self.dropout(hidden_states + positions, training=inputs["training"])
|
hidden_states = self.dropout(hidden_states + positions, training=inputs["training"])
|
||||||
|
|
||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = ()
|
all_hidden_states = () if inputs["output_hidden_states"] else None
|
||||||
all_self_attns = ()
|
all_self_attns = () if inputs["output_attentions"] else None
|
||||||
present_key_values = ()
|
all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
|
||||||
|
present_key_values = () if inputs["use_cache"] else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
# The tf.debugging asserts are not compliant with XLA then they
|
||||||
# have to be disabled in other modes than eager.
|
# have to be disabled in other modes than eager.
|
||||||
if inputs["head_mask"] is not None and tf.executing_eagerly():
|
for attn_mask in ["head_mask", "cross_attn_head_mask"]:
|
||||||
|
if inputs[attn_mask] is not None and tf.executing_eagerly():
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(inputs["head_mask"])[0],
|
shape_list(inputs[attn_mask])[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
|
||||||
)
|
)
|
||||||
|
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
@@ -1018,14 +1027,14 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
|
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
|
||||||
|
|
||||||
hidden_states, layer_self_attn, present_key_value = decoder_layer(
|
hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
encoder_attention_mask=inputs["encoder_attention_mask"],
|
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||||
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||||
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
|
cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
|
||||||
if inputs["encoder_head_mask"] is not None
|
if inputs["cross_attn_head_mask"] is not None
|
||||||
else None,
|
else None,
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
)
|
)
|
||||||
@@ -1036,25 +1045,32 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
|
|||||||
if inputs["output_attentions"]:
|
if inputs["output_attentions"]:
|
||||||
all_self_attns += (layer_self_attn,)
|
all_self_attns += (layer_self_attn,)
|
||||||
|
|
||||||
|
if inputs["encoder_hidden_states"] is not None:
|
||||||
|
all_cross_attns += (layer_cross_attn,)
|
||||||
|
|
||||||
hidden_states = self.layer_norm(hidden_states)
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
|
|
||||||
if inputs["output_hidden_states"]:
|
if inputs["output_hidden_states"]:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
else:
|
|
||||||
all_hidden_states = None
|
|
||||||
|
|
||||||
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
|
if inputs["output_attentions"]:
|
||||||
|
all_self_attns = list(all_self_attns)
|
||||||
|
|
||||||
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
|
if inputs["encoder_hidden_states"] is not None:
|
||||||
|
all_cross_attns = list(all_cross_attns)
|
||||||
|
|
||||||
|
if inputs["use_cache"]:
|
||||||
|
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns
|
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
|
||||||
else:
|
else:
|
||||||
return TFBaseModelOutputWithPast(
|
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=present_key_values,
|
past_key_values=present_key_values,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attns,
|
attentions=all_self_attns,
|
||||||
|
cross_attentions=all_cross_attns,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1101,6 +1117,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1121,6 +1138,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
decoder_head_mask=decoder_head_mask,
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1170,7 +1188,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states=inputs["encoder_outputs"][0],
|
encoder_hidden_states=inputs["encoder_outputs"][0],
|
||||||
encoder_attention_mask=inputs["attention_mask"],
|
encoder_attention_mask=inputs["attention_mask"],
|
||||||
head_mask=inputs["decoder_head_mask"],
|
head_mask=inputs["decoder_head_mask"],
|
||||||
encoder_head_mask=inputs["head_mask"],
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["decoder_inputs_embeds"],
|
inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||||
use_cache=inputs["use_cache"],
|
use_cache=inputs["use_cache"],
|
||||||
@@ -1188,6 +1206,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
|
|||||||
past_key_values=decoder_outputs.past_key_values,
|
past_key_values=decoder_outputs.past_key_values,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs.attentions,
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
|
cross_attentions=decoder_outputs.cross_attentions,
|
||||||
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
||||||
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
||||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||||
@@ -1225,6 +1244,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1245,6 +1265,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
decoder_head_mask=decoder_head_mask,
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1264,6 +1285,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
|
|||||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
decoder_head_mask=inputs["decoder_head_mask"],
|
decoder_head_mask=inputs["decoder_head_mask"],
|
||||||
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
encoder_outputs=inputs["encoder_outputs"],
|
encoder_outputs=inputs["encoder_outputs"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
@@ -1282,6 +1304,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
|
|||||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
|
||||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||||
|
|
||||||
@@ -1290,6 +1313,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
|
|||||||
past_key_values=pkv,
|
past_key_values=pkv,
|
||||||
decoder_hidden_states=dec_hs,
|
decoder_hidden_states=dec_hs,
|
||||||
decoder_attentions=dec_attns,
|
decoder_attentions=dec_attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||||
encoder_hidden_states=enc_hs,
|
encoder_hidden_states=enc_hs,
|
||||||
encoder_attentions=enc_attns,
|
encoder_attentions=enc_attns,
|
||||||
@@ -1344,6 +1368,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1374,6 +1399,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
decoder_head_mask=decoder_head_mask,
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1407,6 +1433,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
|
|||||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||||
head_mask=inputs["head_mask"],
|
head_mask=inputs["head_mask"],
|
||||||
decoder_head_mask=inputs["decoder_head_mask"],
|
decoder_head_mask=inputs["decoder_head_mask"],
|
||||||
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||||
@@ -1429,6 +1456,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
|
|||||||
past_key_values=outputs.past_key_values, # index 1 of d outputs
|
past_key_values=outputs.past_key_values, # index 1 of d outputs
|
||||||
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
|
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
|
||||||
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
|
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
|
||||||
|
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
|
||||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
|
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
|
||||||
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
|
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
|
||||||
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
||||||
@@ -1439,6 +1467,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
|
|||||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
|
||||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||||
|
|
||||||
@@ -1447,6 +1476,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
|
|||||||
past_key_values=pkv,
|
past_key_values=pkv,
|
||||||
decoder_hidden_states=dec_hs,
|
decoder_hidden_states=dec_hs,
|
||||||
decoder_attentions=dec_attns,
|
decoder_attentions=dec_attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||||
encoder_hidden_states=enc_hs,
|
encoder_hidden_states=enc_hs,
|
||||||
encoder_attentions=enc_attns,
|
encoder_attentions=enc_attns,
|
||||||
|
|||||||
@@ -147,7 +147,6 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
|
|||||||
return final_embeddings
|
return final_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}}
|
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}}
|
||||||
class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer):
|
class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer):
|
||||||
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
|
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
|
||||||
@@ -352,6 +351,7 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->{{cookiecutter.camelcase_modelname}}
|
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->{{cookiecutter.camelcase_modelname}}
|
||||||
class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
||||||
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
|
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
|
||||||
@@ -625,7 +625,6 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
|
|||||||
base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
|
base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
|
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
|
||||||
|
|
||||||
This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
|
This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
|
||||||
@@ -885,6 +884,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
|
|||||||
|
|
||||||
return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
|
return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""{{cookiecutter.modelname}} Model with a `language modeling` head on top for CLM fine-tuning. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING
|
"""{{cookiecutter.modelname}} Model with a `language modeling` head on top for CLM fine-tuning. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING
|
||||||
)
|
)
|
||||||
@@ -1728,16 +1728,18 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer):
|
|||||||
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
||||||
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||||
|
|
||||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False):
|
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
attention_mask (:obj:`tf.Tensor`): attention mask of size
|
attention_mask (:obj:`tf.Tensor`): attention mask of size
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||||
|
`(encoder_attention_heads,)`
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states, self_attn_weights, _ = self.self_attn(
|
hidden_states, self_attn_weights, _ = self.self_attn(
|
||||||
hidden_states=hidden_states, attention_mask=attention_mask
|
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
# The tf.debugging asserts are not compliant with XLA then they
|
||||||
@@ -1798,6 +1800,8 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
|
|||||||
attention_mask: Optional[tf.Tensor] = None,
|
attention_mask: Optional[tf.Tensor] = None,
|
||||||
encoder_hidden_states: Optional[tf.Tensor] = None,
|
encoder_hidden_states: Optional[tf.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[tf.Tensor] = None,
|
encoder_attention_mask: Optional[tf.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
|
cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
||||||
training=False,
|
training=False,
|
||||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||||
@@ -1809,6 +1813,10 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
|
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||||
|
`(decoder_attention_heads,)`
|
||||||
|
cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module of size
|
||||||
|
`(decoder_attention_heads,)`
|
||||||
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@@ -1821,6 +1829,7 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
past_key_value=self_attn_past_key_value,
|
past_key_value=self_attn_past_key_value,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
)
|
)
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -1828,15 +1837,17 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# Cross-Attention Block
|
# Cross-Attention Block
|
||||||
cross_attn_present_key_value = None
|
cross_attn_present_key_value = None
|
||||||
|
cross_attn_weights = None
|
||||||
if encoder_hidden_states is not None:
|
if encoder_hidden_states is not None:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
||||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
hidden_states, _, cross_attn_present_key_value = self.encoder_attn(
|
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
|
layer_head_mask=cross_attn_layer_head_mask,
|
||||||
past_key_value=cross_attn_past_key_value,
|
past_key_value=cross_attn_past_key_value,
|
||||||
)
|
)
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
@@ -1858,6 +1869,7 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
|
|||||||
return (
|
return (
|
||||||
hidden_states,
|
hidden_states,
|
||||||
self_attn_weights,
|
self_attn_weights,
|
||||||
|
cross_attn_weights,  # cross-attention weights (None when encoder_hidden_states is None); the decoder unpacks this slot as layer_cross_attn
|
||||||
present_key_value,
|
present_key_value,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1965,6 +1977,24 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
|
|||||||
the right for denoising pre-training following the paper.
|
the right for denoising pre-training following the paper.
|
||||||
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||||
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
|
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
|
||||||
|
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
||||||
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||||
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
||||||
@@ -2013,7 +2043,6 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
|||||||
self.max_source_positions = config.max_position_embeddings
|
self.max_source_positions = config.max_position_embeddings
|
||||||
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
|
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
|
||||||
|
|
||||||
|
|
||||||
self.embed_tokens = embed_tokens
|
self.embed_tokens = embed_tokens
|
||||||
self.embed_positions = TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(
|
self.embed_positions = TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(
|
||||||
config.max_position_embeddings,
|
config.max_position_embeddings,
|
||||||
@@ -2034,6 +2063,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
|||||||
input_ids=None,
|
input_ids=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
@@ -2058,6 +2088,12 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||||
@@ -2082,6 +2118,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
|||||||
config=self.config,
|
config=self.config,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -2115,6 +2152,16 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
|||||||
encoder_states = () if inputs["output_hidden_states"] else None
|
encoder_states = () if inputs["output_hidden_states"] else None
|
||||||
all_attentions = () if inputs["output_attentions"] else None
|
all_attentions = () if inputs["output_attentions"] else None
|
||||||
|
|
||||||
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
# The tf.debugging asserts are not compliant with XLA then they
|
||||||
|
# have to be disabled in other modes than eager.
|
||||||
|
if inputs["head_mask"] is not None and tf.executing_eagerly():
|
||||||
|
tf.debugging.assert_equal(
|
||||||
|
shape_list(inputs["head_mask"])[0],
|
||||||
|
len(self.layers),
|
||||||
|
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
||||||
|
)
|
||||||
|
|
||||||
# encoder layers
|
# encoder layers
|
||||||
for encoder_layer in self.layers:
|
for encoder_layer in self.layers:
|
||||||
|
|
||||||
@@ -2125,7 +2172,11 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
|||||||
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
|
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
|
||||||
continue
|
continue
|
||||||
|
|
||||||
hidden_states, attn = encoder_layer(hidden_states, inputs["attention_mask"])
|
hidden_states, attn = encoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
inputs["attention_mask"],
|
||||||
|
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||||
|
)
|
||||||
|
|
||||||
if inputs["output_attentions"]:
|
if inputs["output_attentions"]:
|
||||||
all_attentions += (attn,)
|
all_attentions += (attn,)
|
||||||
@@ -2181,6 +2232,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
@@ -2218,6 +2271,18 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
decoding.
|
decoding.
|
||||||
@@ -2252,6 +2317,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -2297,8 +2364,20 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
|||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = () if inputs["output_hidden_states"] else None
|
all_hidden_states = () if inputs["output_hidden_states"] else None
|
||||||
all_self_attns = () if inputs["output_attentions"] else None
|
all_self_attns = () if inputs["output_attentions"] else None
|
||||||
|
all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
|
||||||
present_key_values = () if inputs["use_cache"] else None
|
present_key_values = () if inputs["use_cache"] else None
|
||||||
|
|
||||||
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
|
# The tf.debugging asserts are not compliant with XLA then they
|
||||||
|
# have to be disabled in other modes than eager.
|
||||||
|
for attn_mask in ["head_mask", "cross_attn_head_mask"]:
|
||||||
|
if inputs[attn_mask] is not None and tf.executing_eagerly():
|
||||||
|
tf.debugging.assert_equal(
|
||||||
|
shape_list(inputs[attn_mask])[0],
|
||||||
|
len(self.layers),
|
||||||
|
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
|
||||||
|
)
|
||||||
|
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
if inputs["output_hidden_states"]:
|
if inputs["output_hidden_states"]:
|
||||||
@@ -2311,11 +2390,15 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
|
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
|
||||||
|
|
||||||
hidden_states, layer_self_attn, present_key_value = decoder_layer(
|
hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
encoder_attention_mask=inputs["encoder_attention_mask"],
|
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||||
|
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||||
|
cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
|
||||||
|
if inputs["cross_attn_head_mask"] is not None
|
||||||
|
else None,
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2325,23 +2408,30 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
|||||||
if inputs["output_attentions"]:
|
if inputs["output_attentions"]:
|
||||||
all_self_attns += (layer_self_attn,)
|
all_self_attns += (layer_self_attn,)
|
||||||
|
|
||||||
|
if inputs["encoder_hidden_states"] is not None:
|
||||||
|
all_cross_attns += (layer_cross_attn,)
|
||||||
|
|
||||||
if inputs["output_hidden_states"]:
|
if inputs["output_hidden_states"]:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if inputs["output_attentions"]:
|
if inputs["output_attentions"]:
|
||||||
all_self_attns = list(all_self_attns)
|
all_self_attns = list(all_self_attns)
|
||||||
|
|
||||||
|
if inputs["encoder_hidden_states"] is not None:
|
||||||
|
all_cross_attns = list(all_cross_attns)
|
||||||
|
|
||||||
if inputs["use_cache"]:
|
if inputs["use_cache"]:
|
||||||
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
|
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns
|
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
|
||||||
else:
|
else:
|
||||||
return TFBaseModelOutputWithPast(
|
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=present_key_values,
|
past_key_values=present_key_values,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attns,
|
attentions=all_self_attns,
|
||||||
|
cross_attentions=all_cross_attns,
|
||||||
)
|
)
|
||||||
|
|
||||||
@tf.function
|
@tf.function
|
||||||
@@ -2413,6 +2503,9 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -2431,6 +2524,9 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -2450,6 +2546,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
|||||||
inputs["encoder_outputs"] = self.encoder(
|
inputs["encoder_outputs"] = self.encoder(
|
||||||
input_ids=inputs["input_ids"],
|
input_ids=inputs["input_ids"],
|
||||||
attention_mask=inputs["attention_mask"],
|
attention_mask=inputs["attention_mask"],
|
||||||
|
head_mask=inputs["head_mask"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
output_hidden_states=inputs["output_hidden_states"],
|
output_hidden_states=inputs["output_hidden_states"],
|
||||||
@@ -2472,6 +2569,8 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
|||||||
attention_mask=inputs["decoder_attention_mask"],
|
attention_mask=inputs["decoder_attention_mask"],
|
||||||
encoder_hidden_states=inputs["encoder_outputs"][0],
|
encoder_hidden_states=inputs["encoder_outputs"][0],
|
||||||
encoder_attention_mask=inputs["attention_mask"],
|
encoder_attention_mask=inputs["attention_mask"],
|
||||||
|
head_mask=inputs["decoder_head_mask"],
|
||||||
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["decoder_inputs_embeds"],
|
inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||||
use_cache=inputs["use_cache"],
|
use_cache=inputs["use_cache"],
|
||||||
@@ -2489,6 +2588,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
|||||||
past_key_values=decoder_outputs.past_key_values,
|
past_key_values=decoder_outputs.past_key_values,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs.attentions,
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
|
cross_attentions=decoder_outputs.cross_attentions,
|
||||||
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
||||||
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
||||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||||
@@ -2524,6 +2624,9 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -2542,6 +2645,9 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -2559,6 +2665,9 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
|||||||
attention_mask=inputs["attention_mask"],
|
attention_mask=inputs["attention_mask"],
|
||||||
decoder_input_ids=inputs["decoder_input_ids"],
|
decoder_input_ids=inputs["decoder_input_ids"],
|
||||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||||
|
head_mask=inputs["head_mask"],
|
||||||
|
decoder_head_mask=inputs["decoder_head_mask"],
|
||||||
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
encoder_outputs=inputs["encoder_outputs"],
|
encoder_outputs=inputs["encoder_outputs"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
@@ -2577,6 +2686,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
|||||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
|
||||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||||
|
|
||||||
@@ -2585,6 +2695,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
|||||||
past_key_values=pkv,
|
past_key_values=pkv,
|
||||||
decoder_hidden_states=dec_hs,
|
decoder_hidden_states=dec_hs,
|
||||||
decoder_attentions=dec_attns,
|
decoder_attentions=dec_attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||||
encoder_hidden_states=enc_hs,
|
encoder_hidden_states=enc_hs,
|
||||||
encoder_attentions=enc_attns,
|
encoder_attentions=enc_attns,
|
||||||
@@ -2637,6 +2748,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -2672,6 +2786,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -2698,6 +2815,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
|||||||
decoder_input_ids=inputs["decoder_input_ids"],
|
decoder_input_ids=inputs["decoder_input_ids"],
|
||||||
encoder_outputs=inputs["encoder_outputs"],
|
encoder_outputs=inputs["encoder_outputs"],
|
||||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||||
|
head_mask=inputs["head_mask"],
|
||||||
|
decoder_head_mask=inputs["decoder_head_mask"],
|
||||||
|
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||||
past_key_values=inputs["past_key_values"],
|
past_key_values=inputs["past_key_values"],
|
||||||
inputs_embeds=inputs["inputs_embeds"],
|
inputs_embeds=inputs["inputs_embeds"],
|
||||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||||
@@ -2720,6 +2840,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
|||||||
past_key_values=outputs.past_key_values, # index 1 of d outputs
|
past_key_values=outputs.past_key_values, # index 1 of d outputs
|
||||||
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
|
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
|
||||||
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
|
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
|
||||||
|
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
|
||||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
|
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
|
||||||
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
|
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
|
||||||
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
||||||
@@ -2730,6 +2851,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
|||||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||||
|
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
|
||||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||||
|
|
||||||
@@ -2738,6 +2860,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
|||||||
past_key_values=pkv,
|
past_key_values=pkv,
|
||||||
decoder_hidden_states=dec_hs,
|
decoder_hidden_states=dec_hs,
|
||||||
decoder_attentions=dec_attns,
|
decoder_attentions=dec_attns,
|
||||||
|
cross_attentions=cross_attns,
|
||||||
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||||
encoder_hidden_states=enc_hs,
|
encoder_hidden_states=enc_hs,
|
||||||
encoder_attentions=enc_attns,
|
encoder_attentions=enc_attns,
|
||||||
|
|||||||
@@ -147,6 +147,7 @@ def prepare_bart_inputs_dict(
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
):
|
):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||||
@@ -162,13 +163,16 @@ def prepare_bart_inputs_dict(
|
|||||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||||
if decoder_head_mask is None:
|
if decoder_head_mask is None:
|
||||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||||
|
if cross_attn_head_mask is None:
|
||||||
|
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"decoder_attention_mask": decoder_attention_mask,
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
"decoder_head_mask": head_mask,
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -146,6 +146,7 @@ def prepare_blenderbot_inputs_dict(
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
):
|
):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||||
@@ -161,6 +162,8 @@ def prepare_blenderbot_inputs_dict(
|
|||||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||||
if decoder_head_mask is None:
|
if decoder_head_mask is None:
|
||||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||||
|
if cross_attn_head_mask is None:
|
||||||
|
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
@@ -168,6 +171,7 @@ def prepare_blenderbot_inputs_dict(
|
|||||||
"decoder_attention_mask": decoder_attention_mask,
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
"decoder_head_mask": decoder_head_mask,
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -146,6 +146,7 @@ def prepare_blenderbot_small_inputs_dict(
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
):
|
):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||||
@@ -161,6 +162,8 @@ def prepare_blenderbot_small_inputs_dict(
|
|||||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||||
if decoder_head_mask is None:
|
if decoder_head_mask is None:
|
||||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||||
|
if cross_attn_head_mask is None:
|
||||||
|
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
@@ -168,6 +171,7 @@ def prepare_blenderbot_small_inputs_dict(
|
|||||||
"decoder_attention_mask": decoder_attention_mask,
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
"decoder_head_mask": decoder_head_mask,
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -190,8 +190,12 @@ class TFModelTesterMixin:
|
|||||||
"decoder_attention_mask",
|
"decoder_attention_mask",
|
||||||
]
|
]
|
||||||
expected_arg_names.extend(
|
expected_arg_names.extend(
|
||||||
["head_mask", "decoder_head_mask", "encoder_outputs"]
|
["head_mask", "decoder_head_mask"] if "head_mask" and "decoder_head_mask" in arg_names else []
|
||||||
if "head_mask" and "decoder_head_mask" in arg_names
|
)
|
||||||
|
# Necessary to handle BART with newly added cross_attn_head_mask
|
||||||
|
expected_arg_names.extend(
|
||||||
|
["cross_attn_head_mask", "encoder_outputs"]
|
||||||
|
if "cross_attn_head_mask" in arg_names
|
||||||
else ["encoder_outputs"]
|
else ["encoder_outputs"]
|
||||||
)
|
)
|
||||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||||
@@ -512,6 +516,8 @@ class TFModelTesterMixin:
|
|||||||
del inputs_dict["head_mask"]
|
del inputs_dict["head_mask"]
|
||||||
if "decoder_head_mask" in inputs_dict:
|
if "decoder_head_mask" in inputs_dict:
|
||||||
del inputs_dict["decoder_head_mask"]
|
del inputs_dict["decoder_head_mask"]
|
||||||
|
if "cross_attn_head_mask" in inputs_dict:
|
||||||
|
del inputs_dict["cross_attn_head_mask"]
|
||||||
tf_main_layer_classes = set(
|
tf_main_layer_classes = set(
|
||||||
module_member
|
module_member
|
||||||
for model_class in self.all_model_classes
|
for model_class in self.all_model_classes
|
||||||
@@ -639,7 +645,7 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
def check_decoder_attentions_output(outputs):
|
def check_decoder_attentions_output(outputs):
|
||||||
out_len = len(outputs)
|
out_len = len(outputs)
|
||||||
self.assertEqual(out_len % 2, 0)
|
self.assertEqual(min(out_len % 2, out_len % 5), 0) # differentiation due to newly added cross_attentions
|
||||||
decoder_attentions = outputs.decoder_attentions
|
decoder_attentions = outputs.decoder_attentions
|
||||||
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
|
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
@@ -733,6 +739,8 @@ class TFModelTesterMixin:
|
|||||||
arg_names = [*signature.parameters.keys()]
|
arg_names = [*signature.parameters.keys()]
|
||||||
if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model
|
if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model
|
||||||
inputs["decoder_head_mask"] = head_mask
|
inputs["decoder_head_mask"] = head_mask
|
||||||
|
if "cross_attn_head_mask" in arg_names:
|
||||||
|
inputs["cross_attn_head_mask"] = head_mask
|
||||||
|
|
||||||
outputs = model(**inputs, return_dict=True)
|
outputs = model(**inputs, return_dict=True)
|
||||||
|
|
||||||
@@ -757,6 +765,8 @@ class TFModelTesterMixin:
|
|||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
check_attentions_validity(outputs.encoder_attentions)
|
check_attentions_validity(outputs.encoder_attentions)
|
||||||
check_attentions_validity(outputs.decoder_attentions)
|
check_attentions_validity(outputs.decoder_attentions)
|
||||||
|
if "cross_attn_head_mask" in arg_names:
|
||||||
|
check_attentions_validity(outputs.cross_attentions)
|
||||||
else:
|
else:
|
||||||
check_attentions_validity(outputs.attentions)
|
check_attentions_validity(outputs.attentions)
|
||||||
|
|
||||||
|
|||||||
@@ -148,6 +148,7 @@ def prepare_marian_inputs_dict(
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
):
|
):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||||
@@ -163,6 +164,8 @@ def prepare_marian_inputs_dict(
|
|||||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||||
if decoder_head_mask is None:
|
if decoder_head_mask is None:
|
||||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||||
|
if cross_attn_head_mask is None:
|
||||||
|
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
@@ -170,6 +173,7 @@ def prepare_marian_inputs_dict(
|
|||||||
"decoder_attention_mask": decoder_attention_mask,
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
"decoder_head_mask": decoder_head_mask,
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -150,6 +150,7 @@ def prepare_mbart_inputs_dict(
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
):
|
):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||||
@@ -165,13 +166,16 @@ def prepare_mbart_inputs_dict(
|
|||||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||||
if decoder_head_mask is None:
|
if decoder_head_mask is None:
|
||||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||||
|
if cross_attn_head_mask is None:
|
||||||
|
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"decoder_attention_mask": decoder_attention_mask,
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
"decoder_head_mask": head_mask,
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -146,6 +146,7 @@ def prepare_pegasus_inputs_dict(
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
):
|
):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||||
@@ -161,6 +162,8 @@ def prepare_pegasus_inputs_dict(
|
|||||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||||
if decoder_head_mask is None:
|
if decoder_head_mask is None:
|
||||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||||
|
if cross_attn_head_mask is None:
|
||||||
|
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
@@ -168,6 +171,7 @@ def prepare_pegasus_inputs_dict(
|
|||||||
"decoder_attention_mask": decoder_attention_mask,
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
"decoder_head_mask": decoder_head_mask,
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user