Add head_mask/decoder_head_mask for TF BART models (#9639)
* Add head_mask/decoder_head_mask for TF BART models * Add head_mask and decoder_head_mask input arguments for TF BART-based models as a TF counterpart to the PR #9569 * Add test_headmasking functionality to tests/test_modeling_tf_common.py * TODO: Add a test to verify that we can get a gradient back for importance score computation * Remove redundant #TODO note Remove redundant #TODO note from tests/test_modeling_tf_common.py * Fix assertions * Make style * Fix ...Model input args and adjust one new test * Add back head_mask and decoder_head_mask to BART-based ...Model after the last commit * Remove head_mask ande decoder_head_mask from input_dict in TF test_train_pipeline_custom_model as these two have different shape than other input args (Necessary for passing this test) * Revert adding global_rng in test_modeling_tf_common.py
This commit is contained in:
@@ -164,6 +164,7 @@ class TFBartAttention(tf.keras.layers.Layer):
|
||||
key_value_states: Optional[tf.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
layer_head_mask: Optional[tf.Tensor] = None,
|
||||
training=False,
|
||||
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
@@ -230,6 +231,17 @@ class TFBartAttention(tf.keras.layers.Layer):
|
||||
|
||||
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
|
||||
)
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
)
|
||||
attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
@@ -266,16 +278,18 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
|
||||
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
||||
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||
|
||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False):
|
||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
attention_mask (:obj:`tf.Tensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states, self_attn_weights, _ = self.self_attn(
|
||||
hidden_states=hidden_states, attention_mask=attention_mask
|
||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
@@ -331,6 +345,8 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
encoder_hidden_states: Optional[tf.Tensor] = None,
|
||||
encoder_attention_mask: Optional[tf.Tensor] = None,
|
||||
layer_head_mask: Optional[tf.Tensor] = None,
|
||||
encoder_layer_head_mask: Optional[tf.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
||||
training=False,
|
||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||
@@ -342,6 +358,10 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
|
||||
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
`(decoder_attention_heads,)`
|
||||
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`
|
||||
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
||||
"""
|
||||
residual = hidden_states
|
||||
@@ -354,6 +374,7 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = residual + hidden_states
|
||||
@@ -370,6 +391,7 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=encoder_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
@@ -527,6 +549,18 @@ BART_INPUTS_DOCSTRING = r"""
|
||||
the right for denoising pre-training following the paper.
|
||||
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
||||
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
||||
@@ -593,6 +627,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
||||
input_ids=None,
|
||||
inputs_embeds=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
@@ -617,6 +652,12 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||
@@ -635,6 +676,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -670,8 +712,15 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
||||
encoder_states = () if inputs["output_hidden_states"] else None
|
||||
all_attentions = () if inputs["output_attentions"] else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if inputs["head_mask"] is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["head_mask"])[0],
|
||||
len(self.layers),
|
||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
||||
)
|
||||
# encoder layers
|
||||
for encoder_layer in self.layers:
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
|
||||
if inputs["output_hidden_states"]:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
@@ -680,7 +729,11 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
||||
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
|
||||
continue
|
||||
|
||||
hidden_states, attn = encoder_layer(hidden_states, attention_mask)
|
||||
hidden_states, attn = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||
)
|
||||
|
||||
if inputs["output_attentions"]:
|
||||
all_attentions += (attn,)
|
||||
@@ -737,6 +790,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
@@ -774,6 +829,19 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
decoding.
|
||||
@@ -802,6 +870,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_head_mask=encoder_head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
@@ -858,6 +928,13 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
||||
all_self_attns = () if inputs["output_attentions"] else None
|
||||
present_key_values = () if inputs["use_cache"] else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if inputs["head_mask"] is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["head_mask"])[0],
|
||||
len(self.layers),
|
||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
||||
)
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if inputs["output_hidden_states"]:
|
||||
@@ -875,6 +952,10 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=combined_attention_mask,
|
||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
|
||||
if inputs["encoder_head_mask"] is not None
|
||||
else None,
|
||||
past_key_value=past_key_value,
|
||||
)
|
||||
|
||||
@@ -945,6 +1026,8 @@ class TFBartMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -963,6 +1046,8 @@ class TFBartMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -993,6 +1078,7 @@ class TFBartMainLayer(tf.keras.layers.Layer):
|
||||
inputs["encoder_outputs"] = self.encoder(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
@@ -1015,6 +1101,8 @@ class TFBartMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=inputs["decoder_attention_mask"],
|
||||
encoder_hidden_states=inputs["encoder_outputs"][0],
|
||||
encoder_attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["decoder_head_mask"],
|
||||
encoder_head_mask=inputs["head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
@@ -1067,6 +1155,8 @@ class TFBartModel(TFBartPretrainedModel):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1085,6 +1175,8 @@ class TFBartModel(TFBartPretrainedModel):
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -1102,6 +1194,8 @@ class TFBartModel(TFBartPretrainedModel):
|
||||
attention_mask=inputs["attention_mask"],
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
decoder_head_mask=inputs["decoder_head_mask"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
@@ -1179,6 +1273,8 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1207,6 +1303,8 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -1233,6 +1331,8 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
decoder_head_mask=inputs["decoder_head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
@@ -1277,7 +1377,15 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
|
||||
encoder_attentions=enc_attns,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict:
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past,
|
||||
attention_mask,
|
||||
head_mask=None,
|
||||
use_cache=None,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
|
||||
if len(past) == 1:
|
||||
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
|
||||
@@ -1309,6 +1417,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
|
||||
"past_key_values": past_key_values,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
|
||||
@@ -167,6 +167,7 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
|
||||
key_value_states: Optional[tf.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
layer_head_mask: Optional[tf.Tensor] = None,
|
||||
training=False,
|
||||
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
@@ -233,6 +234,17 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
|
||||
|
||||
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
|
||||
)
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
)
|
||||
attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
@@ -270,17 +282,19 @@ class TFBlenderbotEncoderLayer(tf.keras.layers.Layer):
|
||||
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
||||
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||
|
||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False):
|
||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
attention_mask (:obj:`tf.Tensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
hidden_states, self_attn_weights, _ = self.self_attn(
|
||||
hidden_states=hidden_states, attention_mask=attention_mask
|
||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
@@ -336,6 +350,8 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
encoder_hidden_states: Optional[tf.Tensor] = None,
|
||||
encoder_attention_mask: Optional[tf.Tensor] = None,
|
||||
layer_head_mask: Optional[tf.Tensor] = None,
|
||||
encoder_layer_head_mask: Optional[tf.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
||||
training=False,
|
||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||
@@ -347,6 +363,10 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
|
||||
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
`(decoder_attention_heads,)`
|
||||
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`
|
||||
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
||||
"""
|
||||
residual = hidden_states
|
||||
@@ -360,6 +380,7 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = residual + hidden_states
|
||||
@@ -376,6 +397,7 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=encoder_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
@@ -524,6 +546,18 @@ BLENDERBOT_INPUTS_DOCSTRING = r"""
|
||||
:obj:`past_key_values`).
|
||||
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
||||
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
||||
@@ -590,6 +624,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
|
||||
input_ids=None,
|
||||
inputs_embeds=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
@@ -614,6 +649,12 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||
@@ -632,6 +673,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -666,8 +708,15 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
|
||||
encoder_states = () if inputs["output_hidden_states"] else None
|
||||
all_attentions = () if inputs["output_attentions"] else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if inputs["head_mask"] is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["head_mask"])[0],
|
||||
len(self.layers),
|
||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
||||
)
|
||||
# encoder layers
|
||||
for encoder_layer in self.layers:
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
|
||||
if inputs["output_hidden_states"]:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
@@ -676,7 +725,11 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
|
||||
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
|
||||
continue
|
||||
|
||||
hidden_states, attn = encoder_layer(hidden_states, attention_mask)
|
||||
hidden_states, attn = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||
)
|
||||
|
||||
if inputs["output_attentions"]:
|
||||
all_attentions += (attn,)
|
||||
@@ -735,6 +788,8 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
@@ -772,6 +827,19 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
decoding.
|
||||
@@ -800,6 +868,8 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_head_mask=encoder_head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
@@ -855,6 +925,14 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
|
||||
all_hidden_states = ()
|
||||
all_self_attns = ()
|
||||
present_key_values = ()
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if inputs["head_mask"] is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["head_mask"])[0],
|
||||
len(self.layers),
|
||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
||||
)
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if inputs["output_hidden_states"]:
|
||||
@@ -871,6 +949,10 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=combined_attention_mask,
|
||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
|
||||
if inputs["encoder_head_mask"] is not None
|
||||
else None,
|
||||
past_key_value=past_key_value,
|
||||
)
|
||||
|
||||
@@ -943,6 +1025,8 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -961,6 +1045,8 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -983,6 +1069,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
|
||||
inputs["encoder_outputs"] = self.encoder(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
@@ -1005,6 +1092,8 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=inputs["decoder_attention_mask"],
|
||||
encoder_hidden_states=inputs["encoder_outputs"][0],
|
||||
encoder_attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["decoder_head_mask"],
|
||||
encoder_head_mask=inputs["head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
@@ -1070,6 +1159,8 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1088,6 +1179,8 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -1105,6 +1198,8 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
|
||||
attention_mask=inputs["attention_mask"],
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
decoder_head_mask=inputs["decoder_head_mask"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
@@ -1196,6 +1291,8 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1224,6 +1321,8 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -1249,6 +1348,8 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
decoder_head_mask=inputs["decoder_head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
@@ -1295,7 +1396,15 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict:
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past,
|
||||
attention_mask,
|
||||
head_mask=None,
|
||||
use_cache=None,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
|
||||
if len(past) == 1:
|
||||
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
|
||||
@@ -1327,6 +1436,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
|
||||
"past_key_values": past_key_values,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
|
||||
@@ -166,6 +166,7 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
|
||||
key_value_states: Optional[tf.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
layer_head_mask: Optional[tf.Tensor] = None,
|
||||
training=False,
|
||||
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
@@ -232,6 +233,17 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
|
||||
|
||||
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
|
||||
)
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
)
|
||||
attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
@@ -269,16 +281,18 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
|
||||
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
||||
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||
|
||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False):
|
||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
attention_mask (:obj:`tf.Tensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states, self_attn_weights, _ = self.self_attn(
|
||||
hidden_states=hidden_states, attention_mask=attention_mask
|
||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
@@ -335,6 +349,8 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
encoder_hidden_states: Optional[tf.Tensor] = None,
|
||||
encoder_attention_mask: Optional[tf.Tensor] = None,
|
||||
layer_head_mask: Optional[tf.Tensor] = None,
|
||||
encoder_layer_head_mask: Optional[tf.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
||||
training=False,
|
||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||
@@ -346,6 +362,10 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
|
||||
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
`(decoder_attention_heads,)`
|
||||
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`
|
||||
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
||||
"""
|
||||
residual = hidden_states
|
||||
@@ -358,6 +378,7 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = residual + hidden_states
|
||||
@@ -374,6 +395,7 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=encoder_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
@@ -529,6 +551,18 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r"""
|
||||
:obj:`past_key_values`).
|
||||
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
||||
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
||||
@@ -595,6 +629,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
|
||||
input_ids=None,
|
||||
inputs_embeds=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
@@ -619,6 +654,12 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||
@@ -637,6 +678,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -672,8 +714,15 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
|
||||
encoder_states = () if inputs["output_hidden_states"] else None
|
||||
all_attentions = () if inputs["output_attentions"] else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if inputs["head_mask"] is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["head_mask"])[0],
|
||||
len(self.layers),
|
||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
||||
)
|
||||
# encoder layers
|
||||
for encoder_layer in self.layers:
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
|
||||
if inputs["output_hidden_states"]:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
@@ -682,7 +731,11 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
|
||||
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
|
||||
continue
|
||||
|
||||
hidden_states, attn = encoder_layer(hidden_states, attention_mask)
|
||||
hidden_states, attn = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||
)
|
||||
|
||||
if inputs["output_attentions"]:
|
||||
all_attentions += (attn,)
|
||||
@@ -740,6 +793,8 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
@@ -777,6 +832,19 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
decoding.
|
||||
@@ -805,6 +873,8 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_head_mask=encoder_head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
@@ -859,6 +929,13 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
||||
all_self_attns = ()
|
||||
present_key_values = ()
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if inputs["head_mask"] is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["head_mask"])[0],
|
||||
len(self.layers),
|
||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
||||
)
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if inputs["output_hidden_states"]:
|
||||
@@ -875,6 +952,10 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=combined_attention_mask,
|
||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
|
||||
if inputs["encoder_head_mask"] is not None
|
||||
else None,
|
||||
past_key_value=past_key_value,
|
||||
)
|
||||
|
||||
@@ -945,6 +1026,8 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -963,6 +1046,8 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -985,6 +1070,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
|
||||
inputs["encoder_outputs"] = self.encoder(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
@@ -1007,6 +1093,8 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=inputs["decoder_attention_mask"],
|
||||
encoder_hidden_states=inputs["encoder_outputs"][0],
|
||||
encoder_attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["decoder_head_mask"],
|
||||
encoder_head_mask=inputs["head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
@@ -1059,6 +1147,8 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1077,6 +1167,8 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -1094,6 +1186,8 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
|
||||
attention_mask=inputs["attention_mask"],
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
decoder_head_mask=inputs["decoder_head_mask"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
@@ -1172,6 +1266,8 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1200,6 +1296,8 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -1225,6 +1323,8 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
decoder_head_mask=inputs["decoder_head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
@@ -1271,7 +1371,15 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict:
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past,
|
||||
attention_mask,
|
||||
head_mask=None,
|
||||
use_cache=None,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
|
||||
if len(past) == 1:
|
||||
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
|
||||
@@ -1303,6 +1411,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
"past_key_values": past_key_values,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
|
||||
@@ -196,6 +196,7 @@ class TFMarianAttention(tf.keras.layers.Layer):
|
||||
key_value_states: Optional[tf.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
layer_head_mask: Optional[tf.Tensor] = None,
|
||||
training=False,
|
||||
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
@@ -262,6 +263,17 @@ class TFMarianAttention(tf.keras.layers.Layer):
|
||||
|
||||
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
|
||||
)
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
)
|
||||
attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
@@ -299,16 +311,18 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
|
||||
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
||||
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||
|
||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False):
|
||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
attention_mask (:obj:`tf.Tensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states, self_attn_weights, _ = self.self_attn(
|
||||
hidden_states=hidden_states, attention_mask=attention_mask
|
||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
@@ -365,6 +379,8 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
encoder_hidden_states: Optional[tf.Tensor] = None,
|
||||
encoder_attention_mask: Optional[tf.Tensor] = None,
|
||||
layer_head_mask: Optional[tf.Tensor] = None,
|
||||
encoder_layer_head_mask: Optional[tf.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
||||
training=False,
|
||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||
@@ -376,6 +392,10 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
|
||||
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
`(decoder_attention_heads,)`
|
||||
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`
|
||||
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
||||
"""
|
||||
residual = hidden_states
|
||||
@@ -388,6 +408,7 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = residual + hidden_states
|
||||
@@ -404,6 +425,7 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=encoder_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
@@ -548,6 +570,18 @@ MARIAN_INPUTS_DOCSTRING = r"""
|
||||
:obj:`past_key_values`).
|
||||
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
||||
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
||||
@@ -612,6 +646,7 @@ class TFMarianEncoder(tf.keras.layers.Layer):
|
||||
input_ids=None,
|
||||
inputs_embeds=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
@@ -636,6 +671,12 @@ class TFMarianEncoder(tf.keras.layers.Layer):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||
@@ -654,6 +695,7 @@ class TFMarianEncoder(tf.keras.layers.Layer):
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -688,8 +730,15 @@ class TFMarianEncoder(tf.keras.layers.Layer):
|
||||
encoder_states = () if inputs["output_hidden_states"] else None
|
||||
all_attentions = () if inputs["output_attentions"] else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if inputs["head_mask"] is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["head_mask"])[0],
|
||||
len(self.layers),
|
||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
||||
)
|
||||
# encoder layers
|
||||
for encoder_layer in self.layers:
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
|
||||
if inputs["output_hidden_states"]:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
@@ -698,7 +747,11 @@ class TFMarianEncoder(tf.keras.layers.Layer):
|
||||
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
|
||||
continue
|
||||
|
||||
hidden_states, attn = encoder_layer(hidden_states, attention_mask)
|
||||
hidden_states, attn = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||
)
|
||||
|
||||
if inputs["output_attentions"]:
|
||||
all_attentions += (attn,)
|
||||
@@ -753,6 +806,8 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
@@ -790,6 +845,19 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
decoding.
|
||||
@@ -818,6 +886,8 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_head_mask=encoder_head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
@@ -872,6 +942,14 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
||||
all_hidden_states = ()
|
||||
all_self_attns = ()
|
||||
present_key_values = ()
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if inputs["head_mask"] is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["head_mask"])[0],
|
||||
len(self.layers),
|
||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
||||
)
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if inputs["output_hidden_states"]:
|
||||
@@ -888,6 +966,10 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=combined_attention_mask,
|
||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
|
||||
if inputs["encoder_head_mask"] is not None
|
||||
else None,
|
||||
past_key_value=past_key_value,
|
||||
)
|
||||
|
||||
@@ -958,6 +1040,8 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -976,6 +1060,8 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -1001,6 +1087,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
|
||||
inputs["encoder_outputs"] = self.encoder(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
@@ -1023,6 +1110,8 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=inputs["decoder_attention_mask"],
|
||||
encoder_hidden_states=inputs["encoder_outputs"][0],
|
||||
encoder_attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["decoder_head_mask"],
|
||||
encoder_head_mask=inputs["head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
@@ -1075,6 +1164,8 @@ class TFMarianModel(TFMarianPreTrainedModel):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1092,6 +1183,8 @@ class TFMarianModel(TFMarianPreTrainedModel):
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
@@ -1110,6 +1203,8 @@ class TFMarianModel(TFMarianPreTrainedModel):
|
||||
attention_mask=inputs["attention_mask"],
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
decoder_head_mask=inputs["decoder_head_mask"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
@@ -1188,6 +1283,8 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1216,6 +1313,8 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -1242,6 +1341,8 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
decoder_head_mask=inputs["decoder_head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
@@ -1288,7 +1389,15 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict:
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past,
|
||||
attention_mask,
|
||||
head_mask=None,
|
||||
use_cache=None,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
|
||||
if len(past) == 1:
|
||||
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
|
||||
@@ -1320,6 +1429,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
|
||||
"past_key_values": past_key_values,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
|
||||
@@ -170,6 +170,7 @@ class TFMBartAttention(tf.keras.layers.Layer):
|
||||
key_value_states: Optional[tf.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
layer_head_mask: Optional[tf.Tensor] = None,
|
||||
training=False,
|
||||
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
@@ -236,6 +237,17 @@ class TFMBartAttention(tf.keras.layers.Layer):
|
||||
|
||||
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
|
||||
)
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
)
|
||||
attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
@@ -272,17 +284,19 @@ class TFMBartEncoderLayer(tf.keras.layers.Layer):
|
||||
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
||||
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||
|
||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False):
|
||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
attention_mask (:obj:`tf.Tensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
hidden_states, self_attn_weights, _ = self.self_attn(
|
||||
hidden_states=hidden_states, attention_mask=attention_mask
|
||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
@@ -337,6 +351,8 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
encoder_hidden_states: Optional[tf.Tensor] = None,
|
||||
encoder_attention_mask: Optional[tf.Tensor] = None,
|
||||
layer_head_mask: Optional[tf.Tensor] = None,
|
||||
encoder_layer_head_mask: Optional[tf.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
||||
training=False,
|
||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||
@@ -348,6 +364,10 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
|
||||
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
`(decoder_attention_heads,)`
|
||||
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`
|
||||
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
||||
"""
|
||||
residual = hidden_states
|
||||
@@ -361,6 +381,7 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = residual + hidden_states
|
||||
@@ -377,6 +398,7 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=encoder_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
@@ -505,6 +527,18 @@ MBART_INPUTS_DOCSTRING = r"""
|
||||
the right for denoising pre-training following the paper.
|
||||
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
||||
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
||||
@@ -601,6 +635,7 @@ class TFMBartEncoder(tf.keras.layers.Layer):
|
||||
input_ids=None,
|
||||
inputs_embeds=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
@@ -625,6 +660,12 @@ class TFMBartEncoder(tf.keras.layers.Layer):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||
@@ -643,6 +684,7 @@ class TFMBartEncoder(tf.keras.layers.Layer):
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -678,8 +720,15 @@ class TFMBartEncoder(tf.keras.layers.Layer):
|
||||
encoder_states = () if inputs["output_hidden_states"] else None
|
||||
all_attentions = () if inputs["output_attentions"] else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if inputs["head_mask"] is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["head_mask"])[0],
|
||||
len(self.layers),
|
||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
||||
)
|
||||
# encoder layers
|
||||
for encoder_layer in self.layers:
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
|
||||
if inputs["output_hidden_states"]:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
@@ -688,7 +737,11 @@ class TFMBartEncoder(tf.keras.layers.Layer):
|
||||
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
|
||||
continue
|
||||
|
||||
hidden_states, attn = encoder_layer(hidden_states, attention_mask)
|
||||
hidden_states, attn = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||
)
|
||||
|
||||
if inputs["output_attentions"]:
|
||||
all_attentions += (attn,)
|
||||
@@ -748,6 +801,8 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
@@ -785,6 +840,19 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
decoding.
|
||||
@@ -813,6 +881,8 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_head_mask=encoder_head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
@@ -868,6 +938,14 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
||||
all_hidden_states = ()
|
||||
all_self_attns = ()
|
||||
present_key_values = ()
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if inputs["head_mask"] is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["head_mask"])[0],
|
||||
len(self.layers),
|
||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
||||
)
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if inputs["output_hidden_states"]:
|
||||
@@ -884,6 +962,10 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=combined_attention_mask,
|
||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
|
||||
if inputs["encoder_head_mask"] is not None
|
||||
else None,
|
||||
past_key_value=past_key_value,
|
||||
)
|
||||
|
||||
@@ -956,6 +1038,8 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -974,6 +1058,8 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -1002,6 +1088,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
|
||||
inputs["encoder_outputs"] = self.encoder(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
@@ -1024,6 +1111,8 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=inputs["decoder_attention_mask"],
|
||||
encoder_hidden_states=inputs["encoder_outputs"][0],
|
||||
encoder_attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["decoder_head_mask"],
|
||||
encoder_head_mask=inputs["head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
@@ -1076,6 +1165,8 @@ class TFMBartModel(TFMBartPreTrainedModel):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1094,6 +1185,8 @@ class TFMBartModel(TFMBartPreTrainedModel):
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -1111,6 +1204,8 @@ class TFMBartModel(TFMBartPreTrainedModel):
|
||||
attention_mask=inputs["attention_mask"],
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
decoder_head_mask=inputs["decoder_head_mask"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
@@ -1189,6 +1284,8 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1217,6 +1314,8 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -1241,6 +1340,8 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
decoder_head_mask=inputs["decoder_head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
@@ -1287,7 +1388,15 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict:
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past,
|
||||
attention_mask,
|
||||
head_mask=None,
|
||||
use_cache=None,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
|
||||
if len(past) == 1:
|
||||
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
|
||||
@@ -1319,6 +1428,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
|
||||
"past_key_values": past_key_values,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
|
||||
@@ -197,6 +197,7 @@ class TFPegasusAttention(tf.keras.layers.Layer):
|
||||
key_value_states: Optional[tf.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
layer_head_mask: Optional[tf.Tensor] = None,
|
||||
training=False,
|
||||
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
@@ -263,6 +264,17 @@ class TFPegasusAttention(tf.keras.layers.Layer):
|
||||
|
||||
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
|
||||
)
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
)
|
||||
attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
@@ -300,17 +312,19 @@ class TFPegasusEncoderLayer(tf.keras.layers.Layer):
|
||||
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
||||
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||
|
||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False):
|
||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
attention_mask (:obj:`tf.Tensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
hidden_states, self_attn_weights, _ = self.self_attn(
|
||||
hidden_states=hidden_states, attention_mask=attention_mask
|
||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
@@ -366,6 +380,8 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
encoder_hidden_states: Optional[tf.Tensor] = None,
|
||||
encoder_attention_mask: Optional[tf.Tensor] = None,
|
||||
layer_head_mask: Optional[tf.Tensor] = None,
|
||||
encoder_layer_head_mask: Optional[tf.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
||||
training=False,
|
||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||
@@ -377,6 +393,10 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
|
||||
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
`(decoder_attention_heads,)`
|
||||
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`
|
||||
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
||||
"""
|
||||
residual = hidden_states
|
||||
@@ -390,6 +410,7 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = residual + hidden_states
|
||||
@@ -406,6 +427,7 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=encoder_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
@@ -553,6 +575,18 @@ PEGASUS_INPUTS_DOCSTRING = r"""
|
||||
the right for denoising pre-training following the paper.
|
||||
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
||||
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
||||
@@ -618,6 +652,7 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
|
||||
input_ids=None,
|
||||
inputs_embeds=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
@@ -642,6 +677,12 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||
@@ -660,6 +701,7 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -694,8 +736,15 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
|
||||
encoder_states = () if inputs["output_hidden_states"] else None
|
||||
all_attentions = () if inputs["output_attentions"] else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if inputs["head_mask"] is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["head_mask"])[0],
|
||||
len(self.layers),
|
||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
||||
)
|
||||
# encoder layers
|
||||
for encoder_layer in self.layers:
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
|
||||
if inputs["output_hidden_states"]:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
@@ -704,7 +753,11 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
|
||||
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
|
||||
continue
|
||||
|
||||
hidden_states, attn = encoder_layer(hidden_states, attention_mask)
|
||||
hidden_states, attn = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||
)
|
||||
|
||||
if inputs["output_attentions"]:
|
||||
all_attentions += (attn,)
|
||||
@@ -762,6 +815,8 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
@@ -799,6 +854,19 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
decoding.
|
||||
@@ -827,6 +895,8 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_head_mask=encoder_head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
@@ -881,6 +951,14 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
|
||||
all_hidden_states = ()
|
||||
all_self_attns = ()
|
||||
present_key_values = ()
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if inputs["head_mask"] is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["head_mask"])[0],
|
||||
len(self.layers),
|
||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
||||
)
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if inputs["output_hidden_states"]:
|
||||
@@ -897,6 +975,10 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=combined_attention_mask,
|
||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
|
||||
if inputs["encoder_head_mask"] is not None
|
||||
else None,
|
||||
past_key_value=past_key_value,
|
||||
)
|
||||
|
||||
@@ -969,6 +1051,8 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -987,6 +1071,8 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -1012,6 +1098,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
|
||||
inputs["encoder_outputs"] = self.encoder(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
@@ -1034,6 +1121,8 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=inputs["decoder_attention_mask"],
|
||||
encoder_hidden_states=inputs["encoder_outputs"][0],
|
||||
encoder_attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["decoder_head_mask"],
|
||||
encoder_head_mask=inputs["head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
@@ -1086,6 +1175,8 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1104,6 +1195,8 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -1121,6 +1214,8 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
|
||||
attention_mask=inputs["attention_mask"],
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
decoder_head_mask=inputs["decoder_head_mask"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
@@ -1199,6 +1294,8 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1227,6 +1324,8 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -1253,6 +1352,8 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
decoder_head_mask=inputs["decoder_head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
@@ -1299,7 +1400,15 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict:
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past,
|
||||
attention_mask,
|
||||
head_mask=None,
|
||||
use_cache=None,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
|
||||
if len(past) == 1:
|
||||
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
|
||||
@@ -1331,6 +1440,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
|
||||
"past_key_values": past_key_values,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user