From e7381c4596828e5d9fa7974df14f011ed95890c1 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Tue, 9 Feb 2021 17:45:18 +0100 Subject: [PATCH] Add head_mask and decoder_head_mask to TF LED (#9988) * Add head masking to TF LED * Add head_mask to Longformer + one doc piece to LED * Fix integration tests --- .../models/led/modeling_tf_led.py | 131 +++++++++++++++++- .../longformer/modeling_tf_longformer.py | 61 +++++++- tests/test_modeling_tf_led.py | 9 +- tests/test_modeling_tf_longformer.py | 27 +++- 4 files changed, 217 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py index bce2fc5316..783c4da3bb 100644 --- a/src/transformers/models/led/modeling_tf_led.py +++ b/src/transformers/models/led/modeling_tf_led.py @@ -200,6 +200,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): ( hidden_states, attention_mask, + layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn, @@ -275,6 +276,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): attn_probs, ) + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) + attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs + # apply dropout attn_probs = self.dropout(attn_probs, training=training) value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) @@ -310,6 +319,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): attn_output=attn_output, hidden_states=hidden_states, max_num_global_attn_indices=max_num_global_attn_indices, + layer_head_mask=layer_head_mask, is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, is_index_global_attn_nonzero=is_index_global_attn_nonzero, 
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, @@ -752,6 +762,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): attn_output, hidden_states, max_num_global_attn_indices, + layer_head_mask, is_local_index_global_attn_nonzero, is_index_global_attn_nonzero, is_local_index_no_global_attn_nonzero, @@ -817,6 +828,20 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): # compute global attn probs global_attn_probs_float = tf.nn.softmax(global_attn_scores, axis=-1) + # apply layer head masking + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) + global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + ) + global_attn_probs_float = tf.reshape( + global_attn_probs_float, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len) + ) + # dropout global_attn_probs = self.global_dropout(global_attn_probs_float, training=training) @@ -875,13 +900,14 @@ class TFLEDEncoderAttention(tf.keras.layers.Layer): ( hidden_states, attention_mask, + layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn, ) = inputs self_outputs = self.longformer_self_attn( - [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn], + [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn], training=training, ) @@ -927,6 +953,7 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer): key_value_states: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, attention_mask: Optional[tf.Tensor] = None, + layer_head_mask: Optional[tf.Tensor] = None, training=False, ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: """Input shape: Batch x Time x
Channel""" @@ -993,6 +1020,17 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer): attn_weights = tf.nn.softmax(attn_weights, axis=-1) + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + attn_weights, (bsz, self.num_heads, tgt_len, src_len) + ) + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + attn_probs = self.dropout(attn_weights, training=training) attn_output = tf.matmul(attn_probs, value_states) @@ -1031,6 +1069,7 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer): self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, + layer_head_mask: tf.Tensor, is_index_masked: tf.Tensor, is_index_global_attn: tf.Tensor, is_global_attn: bool, @@ -1041,10 +1080,12 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer): hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` attention_mask (:obj:`tf.Tensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. 
""" residual = hidden_states layer_outputs = self.self_attn( - [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn], + [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn], training=training, ) @@ -1104,6 +1145,8 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer): attention_mask: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None, + layer_head_mask: Optional[tf.Tensor] = None, + encoder_layer_head_mask: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None, training=False, ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: @@ -1115,6 +1158,10 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer): encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. + encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of + size `(config.encoder_attention_heads,)`. 
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states """ residual = hidden_states @@ -1127,6 +1174,7 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer): hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, ) hidden_states = self.dropout(hidden_states, training=training) hidden_states = residual + hidden_states @@ -1143,6 +1191,7 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer): hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, + layer_head_mask=encoder_layer_head_mask, past_key_value=cross_attn_past_key_value, ) hidden_states = self.dropout(hidden_states, training=training) @@ -1438,6 +1487,18 @@ LED_INPUTS_DOCSTRING = r""" shifting the input_ids right, following the paper. decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. + head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tf.FloatTensor`, `optional`): hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of @@ -1517,6 +1578,7 @@ class TFLEDEncoder(tf.keras.layers.Layer): inputs_embeds=None, attention_mask=None, global_attention_mask=None, + head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -1541,6 +1603,12 @@ class TFLEDEncoder(tf.keras.layers.Layer): - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`tf.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert :obj:`input_ids` indices @@ -1559,6 +1627,7 @@ class TFLEDEncoder(tf.keras.layers.Layer): config=self.config, input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, global_attention_mask=global_attention_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, @@ -1617,8 +1686,15 @@ class TFLEDEncoder(tf.keras.layers.Layer): encoder_states = () if inputs["output_hidden_states"] else None all_attentions = all_global_attentions = () if inputs["output_attentions"] else None + # check if head_mask has a correct number of layers specified if desired + if inputs["head_mask"] is not None: + tf.debugging.assert_equal( + shape_list(inputs["head_mask"])[0], + len(self.layers), + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + ) # encoder layers - for encoder_layer in self.layers: + for idx, encoder_layer in enumerate(self.layers): if inputs["output_hidden_states"]:
hidden_states_to_add = self.compute_hidden_states(hidden_states, padding_len) @@ -1631,6 +1707,7 @@ class TFLEDEncoder(tf.keras.layers.Layer): layer_outputs = encoder_layer( hidden_states=hidden_states, attention_mask=inputs["attention_mask"], + layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, is_index_masked=is_index_masked, is_index_global_attn=is_index_global_attn, is_global_attn=is_global_attn, @@ -1753,6 +1830,8 @@ class TFLEDDecoder(tf.keras.layers.Layer): attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, past_key_values=None, use_cache=None, output_attentions=None, @@ -1784,6 +1863,19 @@ class TFLEDDecoder(tf.keras.layers.Layer): - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last @@ -1810,6 +1902,8 @@ class TFLEDDecoder(tf.keras.layers.Layer): attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_head_mask=encoder_head_mask, inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, @@ -1865,6 +1959,14 @@ class TFLEDDecoder(tf.keras.layers.Layer): all_hidden_states = () all_self_attns = () present_key_values = () + + # check if head_mask has a correct number of layers specified if desired + if inputs["head_mask"] is not None: + tf.debugging.assert_equal( + shape_list(inputs["head_mask"])[0], + len(self.layers), + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if inputs["output_hidden_states"]: @@ -1881,6 +1983,10 @@ class TFLEDDecoder(tf.keras.layers.Layer): attention_mask=combined_attention_mask, encoder_hidden_states=inputs["encoder_hidden_states"], encoder_attention_mask=inputs["encoder_attention_mask"], + layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + encoder_layer_head_mask=inputs["encoder_head_mask"][idx] + if inputs["encoder_head_mask"] is not None + else None, past_key_value=past_key_value, ) @@ -1950,6 +2056,8 @@ class TFLEDMainLayer(tf.keras.layers.Layer): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFLEDEncoderBaseModelOutput]] = None, global_attention_mask=None, past_key_values=None, @@ -1969,6 +2077,8 @@ class TFLEDMainLayer(tf.keras.layers.Layer): attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + 
head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, global_attention_mask=global_attention_mask, past_key_values=past_key_values, @@ -1990,6 +2100,7 @@ class TFLEDMainLayer(tf.keras.layers.Layer): input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], global_attention_mask=inputs["global_attention_mask"], + head_mask=inputs["head_mask"], inputs_embeds=inputs["inputs_embeds"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], @@ -2012,6 +2123,8 @@ class TFLEDMainLayer(tf.keras.layers.Layer): attention_mask=inputs["decoder_attention_mask"], encoder_hidden_states=inputs["encoder_outputs"][0], encoder_attention_mask=inputs["attention_mask"], + head_mask=inputs["decoder_head_mask"], + encoder_head_mask=inputs["head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["decoder_inputs_embeds"], use_cache=inputs["use_cache"], @@ -2065,6 +2178,8 @@ class TFLEDModel(TFLEDPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFLEDEncoderBaseModelOutput]] = None, global_attention_mask=None, past_key_values=None, @@ -2084,6 +2199,8 @@ class TFLEDModel(TFLEDPreTrainedModel): attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, global_attention_mask=global_attention_mask, past_key_values=past_key_values, @@ -2103,6 +2220,8 @@ class TFLEDModel(TFLEDPreTrainedModel): decoder_attention_mask=inputs["decoder_attention_mask"], encoder_outputs=inputs["encoder_outputs"], global_attention_mask=inputs["global_attention_mask"], + head_mask=inputs["head_mask"], + decoder_head_mask=inputs["decoder_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], 
decoder_inputs_embeds=inputs["decoder_inputs_embeds"], @@ -2180,6 +2299,8 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[TFLEDEncoderBaseModelOutput] = None, global_attention_mask=None, past_key_values=None, @@ -2217,6 +2338,8 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, global_attention_mask=global_attention_mask, past_key_values=past_key_values, @@ -2245,6 +2368,8 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): decoder_attention_mask=inputs["decoder_attention_mask"], encoder_outputs=inputs["encoder_outputs"], global_attention_mask=inputs["global_attention_mask"], + head_mask=inputs["head_mask"], + decoder_head_mask=inputs["decoder_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"], diff --git a/src/transformers/models/longformer/modeling_tf_longformer.py b/src/transformers/models/longformer/modeling_tf_longformer.py index 81f0eb3880..2636b479e8 100644 --- a/src/transformers/models/longformer/modeling_tf_longformer.py +++ b/src/transformers/models/longformer/modeling_tf_longformer.py @@ -719,6 +719,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ( hidden_states, attention_mask, + layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn, @@ -794,6 +795,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): attn_probs, ) + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is 
{shape_list(layer_head_mask)}", + ) + attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs + # apply dropout attn_probs = self.dropout(attn_probs, training=training) value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) @@ -829,6 +838,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): attn_output=attn_output, hidden_states=hidden_states, max_num_global_attn_indices=max_num_global_attn_indices, + layer_head_mask=layer_head_mask, is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, is_index_global_attn_nonzero=is_index_global_attn_nonzero, is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, @@ -1271,6 +1281,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): attn_output, hidden_states, max_num_global_attn_indices, + layer_head_mask, is_local_index_global_attn_nonzero, is_index_global_attn_nonzero, is_local_index_no_global_attn_nonzero, @@ -1336,6 +1347,20 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): # compute global attn probs global_attn_probs_float = tf.nn.softmax(global_attn_scores, axis=-1) + # apply layer head masking + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) + global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + ) + global_attn_probs_float = tf.reshape( + global_attn_probs_float, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len) + ) + # dropout global_attn_probs = self.global_dropout(global_attn_probs_float, training=training) @@ -1398,13 +1423,14 @@ class TFLongformerAttention(tf.keras.layers.Layer): ( hidden_states, attention_mask, + layer_head_mask, is_index_masked, is_index_global_attn,
is_global_attn, ) = inputs self_outputs = self.self_attention( - [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn], + [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn], training=training, ) attention_output = self.dense_output(self_outputs[0], hidden_states, training=training) @@ -1425,13 +1451,14 @@ class TFLongformerLayer(tf.keras.layers.Layer): ( hidden_states, attention_mask, + layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn, ) = inputs attention_outputs = self.attention( - [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn], + [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn], training=training, ) attention_output = attention_outputs[0] @@ -1469,7 +1496,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer): all_hidden_states = () if output_hidden_states else None all_attentions = all_global_attentions = () if output_attentions else None - for i, layer_module in enumerate(self.layer): + for idx, layer_module in enumerate(self.layer): if output_hidden_states: hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states all_hidden_states = all_hidden_states + (hidden_states_to_add,) @@ -1478,6 +1505,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer): [ hidden_states, attention_mask, + head_mask[idx] if head_mask is not None else None, is_index_masked, is_index_global_attn, is_global_attn, @@ -1558,6 +1586,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): self, input_ids=None, attention_mask=None, + head_mask=None, global_attention_mask=None, token_type_ids=None, position_ids=None, @@ -1573,6 +1602,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): config=self.config, input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, global_attention_mask=global_attention_mask, token_type_ids=token_type_ids, 
position_ids=position_ids, @@ -1649,6 +1679,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, + head_mask=head_mask, padding_len=padding_len, is_index_masked=is_index_masked, is_index_global_attn=is_index_global_attn, @@ -1842,6 +1873,12 @@ LONGFORMER_INPUTS_DOCSTRING = r""" - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + global_attention_mask (:obj:`tf.Tensor` of shape :obj:`({0})`, `optional`): Mask to decide the attention given on each token, local attention or global attention. Tokens with global attention attends to all other tokens, and all other tokens attend to them.
This is important for @@ -1918,6 +1955,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel): self, input_ids=None, attention_mask=None, + head_mask=None, global_attention_mask=None, token_type_ids=None, position_ids=None, @@ -1933,6 +1971,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel): config=self.config, input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, global_attention_mask=global_attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -1946,6 +1985,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel): outputs = self.longformer( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], + head_mask=inputs["head_mask"], global_attention_mask=inputs["global_attention_mask"], token_type_ids=inputs["token_type_ids"], position_ids=inputs["position_ids"], @@ -2004,6 +2044,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel self, input_ids=None, attention_mask=None, + head_mask=None, global_attention_mask=None, token_type_ids=None, position_ids=None, @@ -2026,6 +2067,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel config=self.config, input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, global_attention_mask=global_attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -2040,6 +2082,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel outputs = self.longformer( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], + head_mask=inputs["head_mask"], global_attention_mask=inputs["global_attention_mask"], token_type_ids=inputs["token_type_ids"], position_ids=inputs["position_ids"], @@ -2109,6 +2152,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn self, input_ids=None, attention_mask=None, + head_mask=None, global_attention_mask=None, token_type_ids=None, position_ids=None, @@ -2136,6 +2180,7 @@ 
class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn config=self.config, input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, global_attention_mask=global_attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -2170,6 +2215,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn outputs = self.longformer( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], + head_mask=inputs["head_mask"], global_attention_mask=inputs["global_attention_mask"], token_type_ids=inputs["token_type_ids"], position_ids=inputs["position_ids"], @@ -2274,6 +2320,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque self, input_ids=None, attention_mask=None, + head_mask=None, token_type_ids=None, position_ids=None, global_attention_mask=None, @@ -2290,6 +2337,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque config=self.config, input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, global_attention_mask=global_attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -2321,6 +2369,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque outputs = self.longformer( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], + head_mask=inputs["head_mask"], global_attention_mask=inputs["global_attention_mask"], token_type_ids=inputs["token_type_ids"], position_ids=inputs["position_ids"], @@ -2397,6 +2446,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic self, input_ids=None, attention_mask=None, + head_mask=None, token_type_ids=None, position_ids=None, global_attention_mask=None, @@ -2419,6 +2469,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic config=self.config, input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, 
global_attention_mask=global_attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -2464,6 +2515,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic position_ids=flat_position_ids, token_type_ids=flat_token_type_ids, attention_mask=flat_attention_mask, + head_mask=head_mask, global_attention_mask=flat_global_attention_mask, inputs_embeds=flat_inputs_embeds, output_attentions=output_attentions, @@ -2547,6 +2599,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla self, input_ids=None, attention_mask=None, + head_mask=None, token_type_ids=None, position_ids=None, global_attention_mask=None, @@ -2568,6 +2621,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla config=self.config, input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, global_attention_mask=global_attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -2582,6 +2636,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla outputs = self.longformer( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], + head_mask=inputs["head_mask"], global_attention_mask=inputs["global_attention_mask"], token_type_ids=inputs["token_type_ids"], position_ids=inputs["position_ids"], diff --git a/tests/test_modeling_tf_led.py b/tests/test_modeling_tf_led.py index 9299e731a9..55af7528d1 100644 --- a/tests/test_modeling_tf_led.py +++ b/tests/test_modeling_tf_led.py @@ -162,6 +162,8 @@ def prepare_led_inputs_dict( decoder_input_ids, attention_mask=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, ): if attention_mask is None: attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) @@ -173,11 +175,17 @@ def prepare_led_inputs_dict( ], axis=-1, ) + if head_mask is None: + head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads)) + if 
decoder_head_mask is None: + decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) return { "input_ids": input_ids, "attention_mask": attention_mask, "decoder_input_ids": decoder_input_ids, "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, } @@ -187,7 +195,6 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else () is_encoder_decoder = True test_pruning = False - test_head_masking = False def setUp(self): self.model_tester = TFLEDModelTester(self) diff --git a/tests/test_modeling_tf_longformer.py b/tests/test_modeling_tf_longformer.py index 8fddcb9c27..6b600e72e8 100644 --- a/tests/test_modeling_tf_longformer.py +++ b/tests/test_modeling_tf_longformer.py @@ -297,7 +297,6 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase): if is_tf_available() else () ) - test_head_masking = False def setUp(self): self.model_tester = TFLongformerModelTester(self) @@ -517,8 +516,10 @@ class TFLongformerModelIntegrationTest(unittest.TestCase): attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0, attention_mask[:, :, None, None]) is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0) + layer_head_mask = None + output_hidden_states = layer( - [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn] + [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn] )[0] expected_slice = tf.convert_to_tensor( @@ -549,8 +550,17 @@ class TFLongformerModelIntegrationTest(unittest.TestCase): is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0) is_global_attn = tf.math.reduce_any(is_index_global_attn) + layer_head_mask = None + output_hidden_states = layer( - [hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn] + [ + 
hidden_states, + -tf.math.abs(attention_mask), + layer_head_mask, + is_index_masked, + is_index_global_attn, + is_global_attn, + ] )[0] self.assertTrue(output_hidden_states.shape, (2, 4, 8)) @@ -584,8 +594,17 @@ class TFLongformerModelIntegrationTest(unittest.TestCase): is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0) is_global_attn = tf.math.reduce_any(is_index_global_attn) + layer_head_mask = None + output_hidden_states, local_attentions, global_attentions = layer( - [hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn] + [ + hidden_states, + -tf.math.abs(attention_mask), + layer_head_mask, + is_index_masked, + is_index_global_attn, + is_global_attn, + ] ) self.assertEqual(local_attentions.shape, (2, 4, 2, 8))