From 30567c28e81be1ba09249aa5589b8227653ab073 Mon Sep 17 00:00:00 2001
From: Dominik Baran
Date: Tue, 22 Jul 2025 13:07:52 +0200
Subject: [PATCH] [timm_wrapper] add support for gradient checkpointing (#39287)

* feat: add support for gradient checkpointing in TimmWrapperModel and TimmWrapperForImageClassification

* ruff fix

* refactor + add test for not supported model

* ruff

* Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py

Co-authored-by: Pavel Iakubovskii

* Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py

Co-authored-by: Pavel Iakubovskii

* Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py

Co-authored-by: Pavel Iakubovskii

* Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py

Co-authored-by: Pavel Iakubovskii

---------

Co-authored-by: Pavel Iakubovskii
---
 .../timm_wrapper/modeling_timm_wrapper.py     | 22 +++++++++++++++++++
 .../test_modeling_timm_wrapper.py             | 10 +++++++++
 2 files changed, 32 insertions(+)

diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
index 34893bfdf9..d6d844af47 100644
--- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
+++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
@@ -70,6 +70,10 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
         requires_backends(self, ["vision", "timm"])
         super().__init__(*args, **kwargs)
 
+    def post_init(self):
+        self.supports_gradient_checkpointing = self._timm_model_supports_gradient_checkpointing()
+        super().post_init()
+
     @staticmethod
     def _fix_state_dict_key_on_load(key) -> tuple[str, bool]:
         """
@@ -107,6 +111,24 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
             if module.bias is not None:
                 module.bias.data.zero_()
 
+    def _timm_model_supports_gradient_checkpointing(self):
+        """
+        Check if the timm model supports gradient checkpointing by checking if the `set_grad_checkpointing` method is available.
+        Some timm models will have the method but will raise an AssertionError when called so in this case we return False.
+        """
+        if not hasattr(self.timm_model, "set_grad_checkpointing"):
+            return False
+
+        try:
+            self.timm_model.set_grad_checkpointing(enable=True)
+            self.timm_model.set_grad_checkpointing(enable=False)
+            return True
+        except Exception:
+            return False
+
+    def _set_gradient_checkpointing(self, enable: bool = True, *args, **kwargs):
+        self.timm_model.set_grad_checkpointing(enable)
+
 
 class TimmWrapperModel(TimmWrapperPreTrainedModel):
     """
diff --git a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py
index a37f10d381..b7653f4e77 100644
--- a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py
+++ b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py
@@ -170,6 +170,16 @@ class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
     def test_model_is_small(self):
         pass
 
+    def test_gradient_checkpointing(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        model = TimmWrapperModel._from_config(config)
+        self.assertTrue(model.supports_gradient_checkpointing)
+
+    def test_gradient_checkpointing_on_non_supported_model(self):
+        config = TimmWrapperConfig.from_pretrained("timm/hrnet_w18.ms_aug_in1k")
+        model = TimmWrapperModel._from_config(config)
+        self.assertFalse(model.supports_gradient_checkpointing)
+
     def test_forward_signature(self):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
 