Add head_mask and decoder_head_mask to FSMT (#9819)

* Add {decoder_,}head_mask to fsmt_modeling.py

* Enable test_headmasking and some changes to docs

* Remove test_head_masking flag from fsmt test file

Remove test_head_masking flag from test_modeling_fsmt.py
since test_head_masking is set to be True by default (thus it is redundant to store).

* Merge master and remove test_head_masking = True

* Rebase necessary due to an update of jaxlib

* Remove test_head_masking=True in tests/test_modeling_fsmt.py
as it is redundant.
This commit is contained in:
Daniel Stancl
2021-02-01 07:30:21 +01:00
committed by GitHub
parent 74f16b8276
commit 0c6c0afc0e
2 changed files with 109 additions and 17 deletions

View File

@@ -111,12 +111,20 @@ def prepare_fsmt_inputs_dict(
config,
input_ids,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
}
@@ -126,7 +134,6 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_generative_model_classes = (FSMTForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
test_missing_keys = False
def setUp(self):