Fix cross-attention head mask for Torch encoder-decoder models (#10605)

* Fix cross-attention head mask for Torch BART models

* Fix head masking for cross-attention module for the following
models: BART, Blenderbot, Blenderbot_small, M2M_100, Marian, MBart,
Pegasus

* Enable test_headmasking for M2M_100 model

* Fix cross_head_mask for FSMT, LED and T5

* This commit fixes `head_mask` for cross-attention modules
in the following models: FSMT, LED, T5

* It also contains some smaller changes in doc so that
it is be perfectly clear the shape of `cross_head_mask`
is the same as of `decoder_head_mask`

* Update template

* Fix template for BartForCausalLM

* Fix cross_head_mask for Speech2Text models

* Fix cross_head_mask in templates

* Fix args order in BartForCausalLM template

* Fix doc in BART templates

* Make more explicit naming

* `cross_head_mask` -> `cross_attn_head_mask`

* `cross_layer_head_mask` -> `cross_attn_layer_head_mask`

* Fix doc

* make style quality

* Fix speech2text docstring
This commit is contained in:
Daniel Stancl
2021-04-23 18:58:06 +02:00
committed by GitHub
parent ca6b80cadb
commit e3ff165aa5
23 changed files with 587 additions and 389 deletions

View File

@@ -55,6 +55,7 @@ def prepare_bart_inputs_dict(
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
@@ -64,6 +65,8 @@ def prepare_bart_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
@@ -71,6 +74,7 @@ def prepare_bart_inputs_dict(
"decoder_attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}

View File

@@ -45,6 +45,7 @@ def prepare_blenderbot_inputs_dict(
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
@@ -54,6 +55,8 @@ def prepare_blenderbot_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
@@ -61,6 +64,7 @@ def prepare_blenderbot_inputs_dict(
"decoder_attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}

View File

@@ -50,6 +50,7 @@ def prepare_blenderbot_small_inputs_dict(
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
@@ -59,6 +60,8 @@ def prepare_blenderbot_small_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
@@ -66,6 +69,7 @@ def prepare_blenderbot_small_inputs_dict(
"decoder_attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}

View File

@@ -225,8 +225,8 @@ class ModelTesterMixin:
"decoder_attention_mask",
]
expected_arg_names.extend(
["head_mask", "decoder_head_mask", "encoder_outputs"]
if "head_mask" and "decoder_head_mask" in arg_names
["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
else ["encoder_outputs"]
)
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
@@ -492,6 +492,8 @@ class ModelTesterMixin:
arg_names = [*signature.parameters.keys()]
if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model
inputs["decoder_head_mask"] = head_mask
if "cross_attn_head_mask" in arg_names:
inputs["cross_attn_head_mask"] = head_mask
outputs = model(**inputs, return_dict=True)
# Test that we can get a gradient back for importance score computation
@@ -523,6 +525,7 @@ class ModelTesterMixin:
if model.config.is_encoder_decoder:
check_attentions_validity(outputs.encoder_attentions)
check_attentions_validity(outputs.decoder_attentions)
check_attentions_validity(outputs.cross_attentions)
else:
check_attentions_validity(outputs.attentions)
@@ -1093,7 +1096,7 @@ class ModelTesterMixin:
# some params shouldn't be scattered by nn.DataParallel
# so just remove them if they are present.
blacklist_non_batched_params = ["head_mask", "decoder_head_mask"]
blacklist_non_batched_params = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"]
for k in blacklist_non_batched_params:
inputs_dict.pop(k, None)

View File

@@ -113,6 +113,7 @@ def prepare_fsmt_inputs_dict(
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
@@ -120,6 +121,8 @@ def prepare_fsmt_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,

View File

@@ -52,6 +52,7 @@ def prepare_led_inputs_dict(
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
@@ -61,6 +62,8 @@ def prepare_led_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
@@ -68,6 +71,7 @@ def prepare_led_inputs_dict(
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}

View File

@@ -41,16 +41,28 @@ def prepare_m2m_100_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}
@@ -142,9 +154,10 @@ class M2M100ModelTester:
model = M2M100Model(config=config).get_decoder().to(torch_device).eval()
input_ids = inputs_dict["input_ids"]
attention_mask = inputs_dict["attention_mask"]
head_mask = inputs_dict["head_mask"]
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
@@ -217,7 +230,6 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
test_missing_keys = False
def setUp(self):

View File

@@ -60,6 +60,7 @@ def prepare_marian_inputs_dict(
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
@@ -69,6 +70,8 @@ def prepare_marian_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
@@ -76,6 +79,7 @@ def prepare_marian_inputs_dict(
"decoder_attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}

View File

@@ -52,6 +52,7 @@ def prepare_mbart_inputs_dict(
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
@@ -61,6 +62,8 @@ def prepare_mbart_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
@@ -68,6 +71,7 @@ def prepare_mbart_inputs_dict(
"decoder_attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}

View File

@@ -42,6 +42,7 @@ def prepare_pegasus_inputs_dict(
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
@@ -51,6 +52,8 @@ def prepare_pegasus_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
@@ -58,6 +61,7 @@ def prepare_pegasus_inputs_dict(
"decoder_attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}

View File

@@ -55,17 +55,29 @@ def prepare_speech_to_text_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = input_features.ne(0)
if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
# "input_ids": input_features,
"input_features": input_features,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}
@@ -247,7 +259,6 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
test_missing_keys = False
test_torchscript = True
@@ -316,8 +327,8 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
"decoder_attention_mask",
]
expected_arg_names.extend(
["head_mask", "decoder_head_mask", "encoder_outputs"]
if "head_mask" and "decoder_head_mask" in arg_names
["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
else ["encoder_outputs"]
)
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)