From 74026b473e8748706a7a86fd20d6a275306d8ffb Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Tue, 17 Sep 2024 17:39:34 +0800 Subject: [PATCH] =?UTF-8?q?idefics2=20enable=5Finput=5Frequire=5Fgrads=20n?= =?UTF-8?q?ot=20aligned=20with=20disable=5Finput=5Fre=E2=80=A6=20(#33194)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * idefics2 enable_input_require_grads not aligned with disable_input_require_grads make peft+idefics2 checkpoints disable fail Signed-off-by: Wang, Yi * split test case Signed-off-by: Wang, Yi * fix ci failure Signed-off-by: Wang, Yi * refine test Signed-off-by: Wang, Yi --------- Signed-off-by: Wang, Yi --- .../models/idefics2/modeling_idefics2.py | 8 ++++ .../models/speecht5/test_modeling_speecht5.py | 12 ++++++ tests/test_modeling_common.py | 38 +++++++++++++++++++ 3 files changed, 58 insertions(+) diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index f57bdd27fe..08ada424ea 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1256,6 +1256,10 @@ class Idefics2Model(Idefics2PreTrainedModel): make_inputs_require_grads ) + def disable_input_require_grads(self): + self._text_require_grads_hook.remove() + self._vision_require_grads_hook.remove() + def get_input_embeddings(self): return self.text_model.get_input_embeddings() @@ -1466,6 +1470,10 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel): make_inputs_require_grads ) + def disable_input_require_grads(self): + self._text_require_grads_hook.remove() + self._vision_require_grads_hook.remove() + def get_input_embeddings(self): return self.model.text_model.get_input_embeddings() diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py index 7a8aab8327..e13cf8dd56 100644 --- a/tests/models/speecht5/test_modeling_speecht5.py +++ b/tests/models/speecht5/test_modeling_speecht5.py @@ -239,6 +239,12 @@ class SpeechT5ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase def test_torchscript_simple(self): pass + @unittest.skip( + reason="Model returns None for input_embeds, check: https://github.com/huggingface/transformers/issues/33527" + ) + def test_peft_gradient_checkpointing_enable_disable(self): + pass + @require_torch class SpeechT5ForSpeechToTextTester: @@ -1743,6 +1749,12 @@ class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase): def test_training_gradient_checkpointing_use_reentrant_false(self): pass + @unittest.skip( + reason="Model returns None for input_embeds, check: https://github.com/huggingface/transformers/issues/33527" + ) + def test_peft_gradient_checkpointing_enable_disable(self): + pass + # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index da0570290c..c7af0b1c9f 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -403,6 +403,44 @@ class ModelTesterMixin: m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False" ) + def test_peft_gradient_checkpointing_enable_disable(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if not model_class.supports_gradient_checkpointing: + continue + + # at init model should have gradient checkpointing disabled + model = model_class(config) + self.assertFalse(model.is_gradient_checkpointing) + + # check enable works + model._hf_peft_config_loaded = True + try: + model.gradient_checkpointing_enable() + except NotImplementedError: + continue + + self.assertTrue(model.is_gradient_checkpointing) + + # Loop over all modules and check that relevant modules have gradient_checkpointing set to True + for n, m in model.named_modules(): + if hasattr(m, "gradient_checkpointing"): + self.assertTrue( + m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to True" + ) + + # check disable works + model.gradient_checkpointing_disable() + self.assertFalse(model.is_gradient_checkpointing) + + # Loop over all modules and check that relevant modules have gradient_checkpointing set to False + for n, m in model.named_modules(): + if hasattr(m, "gradient_checkpointing"): + self.assertFalse( + m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False" + ) + @is_flaky(description="low likelihood of failure, reason not yet discovered") def test_save_load_fast_init_from_base(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()