Fix head masking for TFT5 (#9877)
* Fix head_mask and decoder_head_mask in TFT5 models * Enable test_headmasking both fot TFT5 tester and TFT5EncoderOnly tester Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -344,7 +344,12 @@ class TFT5Attention(tf.keras.layers.Layer):
|
||||
|
||||
# Mask heads if we want to
|
||||
if layer_head_mask is not None:
|
||||
weights = weights * layer_head_mask
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.n_heads],
|
||||
message=f"Head mask for a single layer should be of size {(self.n_heads)}, but is {shape_list(layer_head_mask)}",
|
||||
)
|
||||
weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * weights
|
||||
|
||||
attn_output = tf.matmul(weights, value_states) # (batch_size, n_heads, query_length, dim_per_head)
|
||||
|
||||
@@ -711,10 +716,6 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
assert inputs["head_mask"] is None, "Head mask not supported"
|
||||
inputs["head_mask"] = [None] * self.num_hidden_layers
|
||||
assert inputs["encoder_head_mask"] is None, "Encoder head mask not supported"
|
||||
inputs["encoder_head_mask"] = [None] * self.num_hidden_layers
|
||||
present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
|
||||
all_hidden_states = () if inputs["output_hidden_states"] else None
|
||||
all_attentions = () if inputs["output_attentions"] else None
|
||||
@@ -723,7 +724,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
|
||||
hidden_states = self.dropout(inputs["inputs_embeds"], training=inputs["training"])
|
||||
|
||||
for i, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])):
|
||||
for idx, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])):
|
||||
if inputs["output_hidden_states"]:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
layer_outputs = layer_module(
|
||||
@@ -733,8 +734,10 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
||||
layer_head_mask=inputs["head_mask"][i],
|
||||
encoder_layer_head_mask=inputs["encoder_head_mask"][i],
|
||||
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
|
||||
if inputs["encoder_head_mask"] is not None
|
||||
else None,
|
||||
past_key_value=past_key_value,
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
@@ -1057,7 +1060,7 @@ T5_ENCODER_INPUTS_DOCSTRING = r"""
|
||||
behaviors between training and evaluation).
|
||||
"""
|
||||
|
||||
__HEAD_MASK_WARNING_MSG = """
|
||||
_HEAD_MASK_WARNING_MSG = """
|
||||
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
|
||||
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
|
||||
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = tf.ones((num_layers,
|
||||
@@ -1133,7 +1136,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
||||
"""
|
||||
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
|
||||
if head_mask is not None and decoder_head_mask is None:
|
||||
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
|
||||
warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
|
||||
decoder_head_mask = head_mask
|
||||
|
||||
inputs = input_processing(
|
||||
@@ -1327,7 +1330,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
"""
|
||||
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
|
||||
if head_mask is not None and decoder_head_mask is None:
|
||||
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
|
||||
warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
|
||||
decoder_head_mask = head_mask
|
||||
|
||||
inputs = input_processing(
|
||||
|
||||
Reference in New Issue
Block a user