Add head_mask, decoder_head_mask, cross_head_mask to ProphetNet (#9964)

* Add head_mask & decoder_head_mask + some corrections

* Fix head masking for N-grams

* Enable test_headmasking for encoder and decod

* Fix one typo regarding in modeling_propgetnet.py

* Enable test_headmasking for ProphetNetStandaloneDecoderModelTest
and ProphetNetStandaloneEncoderModelTest in test_modeling_prophetnet.py

* make style

* Fix cross_head_mask

* Fix attention head mask naming

* `cross_head_mask` -> `cross_attn_head_mask`

* `cross_layer_head_mask` -> `cross_attn_layer_head_mask`

* Still need to merge #10605 to master to pass the tests
This commit is contained in:
Daniel Stancl
2021-04-25 11:06:16 +02:00
committed by GitHub
parent 52166f672e
commit f45cb66bf6
2 changed files with 145 additions and 8 deletions

View File

@@ -891,7 +891,6 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
test_pruning = False
test_torchscript = False
test_resize_embeddings = False
test_headmasking = False
is_encoder_decoder = True
def setUp(self):
@@ -1097,7 +1096,6 @@ class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMix
test_pruning = False
test_torchscript = False
test_resize_embeddings = False
test_headmasking = False
is_encoder_decoder = False
def setUp(self):
@@ -1126,7 +1124,6 @@ class ProphetNetStandaloneEncoderModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False
test_torchscript = False
test_resize_embeddings = False
test_headmasking = False
is_encoder_decoder = False
def setUp(self):