Fix usage of head masks by PT encoder-decoder models' generate() function (#11621)

* Add missing head masking for generate() function

* Add head_mask, decoder_head_mask and cross_attn_head_mask
into prepare_inputs_for_generation for generate() function
for multiple encoder-decoder models.

* Add test_genereate_with_head_masking

* [WIP] Update the new test and handle special cases

* make style

* Omit ProphetNet test so far

* make fix-copies
This commit is contained in:
Daniel Stancl
2021-05-19 01:44:53 +02:00
committed by GitHub
parent ca33278fdb
commit 680d181ce8
16 changed files with 148 additions and 4 deletions

View File

@@ -1088,6 +1088,10 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
self.assertIsNotNone(encoder_hidden_states.grad)
self.assertIsNotNone(encoder_attentions.grad)
def test_generate_with_head_masking(self):
"""Generating with head_masking has not been implemented for ProphetNet models yet."""
pass
@require_torch
class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):