Fix head masking for TFT5 (#9877)
* Fix head_mask and decoder_head_mask in TFT5 models * Enable test_headmasking both fot TFT5 tester and TFT5EncoderOnly tester Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -344,7 +344,12 @@ class TFT5Attention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# Mask heads if we want to
|
# Mask heads if we want to
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
weights = weights * layer_head_mask
|
tf.debugging.assert_equal(
|
||||||
|
shape_list(layer_head_mask),
|
||||||
|
[self.n_heads],
|
||||||
|
message=f"Head mask for a single layer should be of size {(self.n_heads)}, but is {shape_list(layer_head_mask)}",
|
||||||
|
)
|
||||||
|
weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * weights
|
||||||
|
|
||||||
attn_output = tf.matmul(weights, value_states) # (batch_size, n_heads, query_length, dim_per_head)
|
attn_output = tf.matmul(weights, value_states) # (batch_size, n_heads, query_length, dim_per_head)
|
||||||
|
|
||||||
@@ -711,10 +716,6 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
else:
|
else:
|
||||||
encoder_extended_attention_mask = None
|
encoder_extended_attention_mask = None
|
||||||
|
|
||||||
assert inputs["head_mask"] is None, "Head mask not supported"
|
|
||||||
inputs["head_mask"] = [None] * self.num_hidden_layers
|
|
||||||
assert inputs["encoder_head_mask"] is None, "Encoder head mask not supported"
|
|
||||||
inputs["encoder_head_mask"] = [None] * self.num_hidden_layers
|
|
||||||
present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
|
present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
|
||||||
all_hidden_states = () if inputs["output_hidden_states"] else None
|
all_hidden_states = () if inputs["output_hidden_states"] else None
|
||||||
all_attentions = () if inputs["output_attentions"] else None
|
all_attentions = () if inputs["output_attentions"] else None
|
||||||
@@ -723,7 +724,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
hidden_states = self.dropout(inputs["inputs_embeds"], training=inputs["training"])
|
hidden_states = self.dropout(inputs["inputs_embeds"], training=inputs["training"])
|
||||||
|
|
||||||
for i, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])):
|
for idx, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])):
|
||||||
if inputs["output_hidden_states"]:
|
if inputs["output_hidden_states"]:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
layer_outputs = layer_module(
|
layer_outputs = layer_module(
|
||||||
@@ -733,8 +734,10 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
encoder_attention_mask=encoder_extended_attention_mask,
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
||||||
layer_head_mask=inputs["head_mask"][i],
|
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||||
encoder_layer_head_mask=inputs["encoder_head_mask"][i],
|
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
|
||||||
|
if inputs["encoder_head_mask"] is not None
|
||||||
|
else None,
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
use_cache=inputs["use_cache"],
|
use_cache=inputs["use_cache"],
|
||||||
output_attentions=inputs["output_attentions"],
|
output_attentions=inputs["output_attentions"],
|
||||||
@@ -1057,7 +1060,7 @@ T5_ENCODER_INPUTS_DOCSTRING = r"""
|
|||||||
behaviors between training and evaluation).
|
behaviors between training and evaluation).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__HEAD_MASK_WARNING_MSG = """
|
_HEAD_MASK_WARNING_MSG = """
|
||||||
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
|
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
|
||||||
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
|
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
|
||||||
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = tf.ones((num_layers,
|
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = tf.ones((num_layers,
|
||||||
@@ -1133,7 +1136,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
|
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
|
||||||
if head_mask is not None and decoder_head_mask is None:
|
if head_mask is not None and decoder_head_mask is None:
|
||||||
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
|
warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
|
||||||
decoder_head_mask = head_mask
|
decoder_head_mask = head_mask
|
||||||
|
|
||||||
inputs = input_processing(
|
inputs = input_processing(
|
||||||
@@ -1327,7 +1330,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||||||
"""
|
"""
|
||||||
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
|
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
|
||||||
if head_mask is not None and decoder_head_mask is None:
|
if head_mask is not None and decoder_head_mask is None:
|
||||||
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
|
warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
|
||||||
decoder_head_mask = head_mask
|
decoder_head_mask = head_mask
|
||||||
|
|
||||||
inputs = input_processing(
|
inputs = input_processing(
|
||||||
|
|||||||
@@ -248,7 +248,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
|
all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
|
||||||
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
|
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
|
||||||
test_head_masking = False
|
|
||||||
test_onnx = False
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@@ -427,7 +426,6 @@ class TFT5EncoderOnlyModelTester:
|
|||||||
class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
|
class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
|
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
|
||||||
test_head_masking = False
|
|
||||||
test_onnx = False
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user