From 897a24c869e2ac2ed44f17956f1009fd8f055f5e Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 26 Jan 2021 11:02:48 +0100 Subject: [PATCH] Fix head_mask for model templates --- .../test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py index c9637cd607..b1a13c997f 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -252,6 +252,8 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTest(TFModelTesterMixin, unitte else () ) + test_head_masking = False + def setUp(self): self.model_tester = TF{{cookiecutter.camelcase_modelname}}ModelTester(self) self.config_tester = ConfigTester(self, config_class={{cookiecutter.camelcase_modelname}}Config, hidden_size=37) @@ -475,6 +477,7 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTest(TFModelTesterMixin, unitte all_generative_model_classes = (TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,) if is_tf_available() else () is_encoder_decoder = True test_pruning = False + test_head_masking = False def setUp(self): self.model_tester = TF{{cookiecutter.camelcase_modelname}}ModelTester(self)