Fix cross-attention head mask for Torch encoder-decoder models (#10605)
* Fix cross-attention head mask for Torch BART models * Fix head masking for cross-attention module for the following models: BART, Blenderbot, Blenderbot_small, M2M_100, Marian, MBart, Pegasus * Enable test_headmasking for M2M_100 model * Fix cross_head_mask for FSMT, LED and T5 * This commit fixes `head_mask` for cross-attention modules in the following models: FSMT, LED, T5 * It also contains some smaller changes in doc so that it is be perfectly clear the shape of `cross_head_mask` is the same as of `decoder_head_mask` * Update template * Fix template for BartForCausalLM * Fix cross_head_mask for Speech2Text models * Fix cross_head_mask in templates * Fix args order in BartForCausalLM template * Fix doc in BART templates * Make more explicit naming * `cross_head_mask` -> `cross_attn_head_mask` * `cross_layer_head_mask` -> `cross_attn_layer_head_mask` * Fix doc * make style quality * Fix speech2text docstring
This commit is contained in:
@@ -225,8 +225,8 @@ class ModelTesterMixin:
|
||||
"decoder_attention_mask",
|
||||
]
|
||||
expected_arg_names.extend(
|
||||
["head_mask", "decoder_head_mask", "encoder_outputs"]
|
||||
if "head_mask" and "decoder_head_mask" in arg_names
|
||||
["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
|
||||
if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
|
||||
else ["encoder_outputs"]
|
||||
)
|
||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||
@@ -492,6 +492,8 @@ class ModelTesterMixin:
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model
|
||||
inputs["decoder_head_mask"] = head_mask
|
||||
if "cross_attn_head_mask" in arg_names:
|
||||
inputs["cross_attn_head_mask"] = head_mask
|
||||
outputs = model(**inputs, return_dict=True)
|
||||
|
||||
# Test that we can get a gradient back for importance score computation
|
||||
@@ -523,6 +525,7 @@ class ModelTesterMixin:
|
||||
if model.config.is_encoder_decoder:
|
||||
check_attentions_validity(outputs.encoder_attentions)
|
||||
check_attentions_validity(outputs.decoder_attentions)
|
||||
check_attentions_validity(outputs.cross_attentions)
|
||||
else:
|
||||
check_attentions_validity(outputs.attentions)
|
||||
|
||||
@@ -1093,7 +1096,7 @@ class ModelTesterMixin:
|
||||
|
||||
# some params shouldn't be scattered by nn.DataParallel
|
||||
# so just remove them if they are present.
|
||||
blacklist_non_batched_params = ["head_mask", "decoder_head_mask"]
|
||||
blacklist_non_batched_params = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"]
|
||||
for k in blacklist_non_batched_params:
|
||||
inputs_dict.pop(k, None)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user