TF BART models - Add cross_attentions to model output and fix cross-attention head masking (#10699)
* Add cross_attn_head_mask to BART * Fix cross_attentions in TFBart-like models * This commit enables returning of `cross_attentions` for TFBart-like models * It also fixes attention head masking in cross-attention module * Update TF model templates * Fix missing , in TF model templates * Fix typo: congig -> config
This commit is contained in:
@@ -190,8 +190,12 @@ class TFModelTesterMixin:
|
||||
"decoder_attention_mask",
|
||||
]
|
||||
expected_arg_names.extend(
|
||||
["head_mask", "decoder_head_mask", "encoder_outputs"]
|
||||
if "head_mask" and "decoder_head_mask" in arg_names
|
||||
["head_mask", "decoder_head_mask"] if "head_mask" and "decoder_head_mask" in arg_names else []
|
||||
)
|
||||
# Necessary to handle BART with newly added cross_attn_head_mask
|
||||
expected_arg_names.extend(
|
||||
["cross_attn_head_mask", "encoder_outputs"]
|
||||
if "cross_attn_head_mask" in arg_names
|
||||
else ["encoder_outputs"]
|
||||
)
|
||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||
@@ -512,6 +516,8 @@ class TFModelTesterMixin:
|
||||
del inputs_dict["head_mask"]
|
||||
if "decoder_head_mask" in inputs_dict:
|
||||
del inputs_dict["decoder_head_mask"]
|
||||
if "cross_attn_head_mask" in inputs_dict:
|
||||
del inputs_dict["cross_attn_head_mask"]
|
||||
tf_main_layer_classes = set(
|
||||
module_member
|
||||
for model_class in self.all_model_classes
|
||||
@@ -639,7 +645,7 @@ class TFModelTesterMixin:
|
||||
|
||||
def check_decoder_attentions_output(outputs):
|
||||
out_len = len(outputs)
|
||||
self.assertEqual(out_len % 2, 0)
|
||||
self.assertEqual(min(out_len % 2, out_len % 5), 0) # differentiation due to newly added cross_attentions
|
||||
decoder_attentions = outputs.decoder_attentions
|
||||
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
@@ -733,6 +739,8 @@ class TFModelTesterMixin:
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
if "decoder_head_mask" in arg_names: # necessary differentiation because of T5 model
|
||||
inputs["decoder_head_mask"] = head_mask
|
||||
if "cross_attn_head_mask" in arg_names:
|
||||
inputs["cross_attn_head_mask"] = head_mask
|
||||
|
||||
outputs = model(**inputs, return_dict=True)
|
||||
|
||||
@@ -757,6 +765,8 @@ class TFModelTesterMixin:
|
||||
if model.config.is_encoder_decoder:
|
||||
check_attentions_validity(outputs.encoder_attentions)
|
||||
check_attentions_validity(outputs.decoder_attentions)
|
||||
if "cross_attn_head_mask" in arg_names:
|
||||
check_attentions_validity(outputs.cross_attentions)
|
||||
else:
|
||||
check_attentions_validity(outputs.attentions)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user