From 8d79e5ca49ea27ded98de927d220d830f34b7124 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Wed, 17 Feb 2021 17:00:09 +0100 Subject: [PATCH] Fix head masking for TFT5 (#9877) * Fix head_mask and decoder_head_mask in TFT5 models * Enable test_headmasking both fot TFT5 tester and TFT5EncoderOnly tester Co-authored-by: patrickvonplaten --- src/transformers/models/t5/modeling_tf_t5.py | 25 +++++++++++--------- tests/test_modeling_tf_t5.py | 2 -- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index db58a10af4..d057ccc9cb 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -344,7 +344,12 @@ class TFT5Attention(tf.keras.layers.Layer): # Mask heads if we want to if layer_head_mask is not None: - weights = weights * layer_head_mask + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.n_heads], + message=f"Head mask for a single layer should be of size {(self.n_heads)}, but is {shape_list(layer_head_mask)}", + ) + weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * weights attn_output = tf.matmul(weights, value_states) # (batch_size, n_heads, query_length, dim_per_head) @@ -711,10 +716,6 @@ class TFT5MainLayer(tf.keras.layers.Layer): else: encoder_extended_attention_mask = None - assert inputs["head_mask"] is None, "Head mask not supported" - inputs["head_mask"] = [None] * self.num_hidden_layers - assert inputs["encoder_head_mask"] is None, "Encoder head mask not supported" - inputs["encoder_head_mask"] = [None] * self.num_hidden_layers present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None all_hidden_states = () if inputs["output_hidden_states"] else None all_attentions = () if inputs["output_attentions"] else None @@ -723,7 +724,7 @@ class TFT5MainLayer(tf.keras.layers.Layer): hidden_states = self.dropout(inputs["inputs_embeds"], training=inputs["training"]) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])): + for idx, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])): if inputs["output_hidden_states"]: all_hidden_states = all_hidden_states + (hidden_states,) layer_outputs = layer_module( @@ -733,8 +734,10 @@ class TFT5MainLayer(tf.keras.layers.Layer): encoder_hidden_states=inputs["encoder_hidden_states"], encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=inputs["head_mask"][i], - encoder_layer_head_mask=inputs["encoder_head_mask"][i], + layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + encoder_layer_head_mask=inputs["encoder_head_mask"][idx] + if inputs["encoder_head_mask"] is not None + else None, past_key_value=past_key_value, use_cache=inputs["use_cache"], output_attentions=inputs["output_attentions"], @@ -1057,7 +1060,7 @@ T5_ENCODER_INPUTS_DOCSTRING = r""" behaviors between training and evaluation). """ -__HEAD_MASK_WARNING_MSG = """ +_HEAD_MASK_WARNING_MSG = """ The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = tf.ones((num_layers, @@ -1133,7 +1136,7 @@ class TFT5Model(TFT5PreTrainedModel): """ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask if head_mask is not None and decoder_head_mask is None: - warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning) decoder_head_mask = head_mask inputs = input_processing( @@ -1327,7 +1330,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling """ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask if head_mask is not None and decoder_head_mask is None: - warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning) decoder_head_mask = head_mask inputs = input_processing( diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py index b611c8553c..fb215a3880 100644 --- a/tests/test_modeling_tf_t5.py +++ b/tests/test_modeling_tf_t5.py @@ -248,7 +248,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): is_encoder_decoder = True all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else () all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else () - test_head_masking = False test_onnx = False def setUp(self): @@ -427,7 +426,6 @@ class TFT5EncoderOnlyModelTester: class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase): is_encoder_decoder = False all_model_classes = (TFT5EncoderModel,) if is_tf_available() else () - test_head_masking = False test_onnx = False def setUp(self):