Add head_mask/decoder_head_mask for TF BART models (#9639)
* Add head_mask/decoder_head_mask for TF BART models * Add head_mask and decoder_head_mask input arguments for TF BART-based models as a TF counterpart to the PR #9569 * Add test_headmasking functionality to tests/test_modeling_tf_common.py * TODO: Add a test to verify that we can get a gradient back for importance score computation * Remove redundant #TODO note Remove redundant #TODO note from tests/test_modeling_tf_common.py * Fix assertions * Make style * Fix ...Model input args and adjust one new test * Add back head_mask and decoder_head_mask to BART-based ...Model after the last commit * Remove head_mask ande decoder_head_mask from input_dict in TF test_train_pipeline_custom_model as these two have different shape than other input args (Necessary for passing this test) * Revert adding global_rng in test_modeling_tf_common.py
This commit is contained in:
@@ -166,6 +166,7 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
|
||||
key_value_states: Optional[tf.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
layer_head_mask: Optional[tf.Tensor] = None,
|
||||
training=False,
|
||||
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
@@ -232,6 +233,17 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
|
||||
|
||||
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
|
||||
)
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
)
|
||||
attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
@@ -269,16 +281,18 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
|
||||
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
||||
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||
|
||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False):
|
||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
attention_mask (:obj:`tf.Tensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states, self_attn_weights, _ = self.self_attn(
|
||||
hidden_states=hidden_states, attention_mask=attention_mask
|
||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
@@ -335,6 +349,8 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
encoder_hidden_states: Optional[tf.Tensor] = None,
|
||||
encoder_attention_mask: Optional[tf.Tensor] = None,
|
||||
layer_head_mask: Optional[tf.Tensor] = None,
|
||||
encoder_layer_head_mask: Optional[tf.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
||||
training=False,
|
||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||
@@ -346,6 +362,10 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
|
||||
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
`(decoder_attention_heads,)`
|
||||
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`
|
||||
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
||||
"""
|
||||
residual = hidden_states
|
||||
@@ -358,6 +378,7 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = residual + hidden_states
|
||||
@@ -374,6 +395,7 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=encoder_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
@@ -529,6 +551,18 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r"""
|
||||
:obj:`past_key_values`).
|
||||
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
||||
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
||||
@@ -595,6 +629,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
|
||||
input_ids=None,
|
||||
inputs_embeds=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
@@ -619,6 +654,12 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||
@@ -637,6 +678,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -672,8 +714,15 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
|
||||
encoder_states = () if inputs["output_hidden_states"] else None
|
||||
all_attentions = () if inputs["output_attentions"] else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if inputs["head_mask"] is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["head_mask"])[0],
|
||||
len(self.layers),
|
||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
||||
)
|
||||
# encoder layers
|
||||
for encoder_layer in self.layers:
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
|
||||
if inputs["output_hidden_states"]:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
@@ -682,7 +731,11 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
|
||||
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
|
||||
continue
|
||||
|
||||
hidden_states, attn = encoder_layer(hidden_states, attention_mask)
|
||||
hidden_states, attn = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||
)
|
||||
|
||||
if inputs["output_attentions"]:
|
||||
all_attentions += (attn,)
|
||||
@@ -740,6 +793,8 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
@@ -777,6 +832,19 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
decoding.
|
||||
@@ -805,6 +873,8 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_head_mask=encoder_head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
@@ -859,6 +929,13 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
||||
all_self_attns = ()
|
||||
present_key_values = ()
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if inputs["head_mask"] is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["head_mask"])[0],
|
||||
len(self.layers),
|
||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
||||
)
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if inputs["output_hidden_states"]:
|
||||
@@ -875,6 +952,10 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
||||
attention_mask=combined_attention_mask,
|
||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
|
||||
if inputs["encoder_head_mask"] is not None
|
||||
else None,
|
||||
past_key_value=past_key_value,
|
||||
)
|
||||
|
||||
@@ -945,6 +1026,8 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -963,6 +1046,8 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -985,6 +1070,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
|
||||
inputs["encoder_outputs"] = self.encoder(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
@@ -1007,6 +1093,8 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=inputs["decoder_attention_mask"],
|
||||
encoder_hidden_states=inputs["encoder_outputs"][0],
|
||||
encoder_attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["decoder_head_mask"],
|
||||
encoder_head_mask=inputs["head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
@@ -1059,6 +1147,8 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1077,6 +1167,8 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -1094,6 +1186,8 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
|
||||
attention_mask=inputs["attention_mask"],
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
decoder_head_mask=inputs["decoder_head_mask"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
@@ -1172,6 +1266,8 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1200,6 +1296,8 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -1225,6 +1323,8 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
decoder_head_mask=inputs["decoder_head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
@@ -1271,7 +1371,15 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict:
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past,
|
||||
attention_mask,
|
||||
head_mask=None,
|
||||
use_cache=None,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
|
||||
if len(past) == 1:
|
||||
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
|
||||
@@ -1303,6 +1411,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
"past_key_values": past_key_values,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user