Fix head masking for TFT5 (#9877)

* Fix head_mask and decoder_head_mask in TFT5 models

* Enable test_headmasking both fot TFT5 tester
and TFT5EncoderOnly tester

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
Daniel Stancl
2021-02-17 17:00:09 +01:00
committed by GitHub
parent 4b91965731
commit 8d79e5ca49
2 changed files with 14 additions and 13 deletions

View File

@@ -248,7 +248,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
is_encoder_decoder = True
all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
test_head_masking = False
test_onnx = False
def setUp(self):
@@ -427,7 +426,6 @@ class TFT5EncoderOnlyModelTester:
class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
is_encoder_decoder = False
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
test_head_masking = False
test_onnx = False
def setUp(self):