Fix usage of head masks by TF encoder-decoder models' generate() function (#11775)

* Fix Bart

* Fix Blenderbot{,_small}

* Fix LED

* Fix Marian

* Fix MBart

* Fix Pegasus

* Fix T5

* Add test for generation with head_mask

* Add a common TF test

* Override a test for the LED model as head masking is not yet properly implemented

* Remove all head_masks from input preparation for LED

* Drop masking for T5 as it needs a bit of refactor
This commit is contained in:
Daniel Stancl
2021-05-26 15:02:44 +02:00
committed by GitHub
parent 0b0a598452
commit 0b93358447
11 changed files with 84 additions and 2 deletions

View File

@@ -310,6 +310,10 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFT5Model.from_pretrained("t5-small")
self.assertIsNotNone(model)
def test_generate_with_headmasking(self):
# TODO: Fix head-masking according to PyTorch T5 model
pass
class TFT5EncoderOnlyModelTester:
def __init__(