[timm_wrapper] add support for gradient checkpointing (#39287)

* feat: add support for gradient checkpointing in TimmWrapperModel and TimmWrapperForImageClassification

* ruff fix

* refactor + add test for not supported model

* ruff

* Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

---------

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
Dominik Baran
2025-07-22 13:07:52 +02:00
committed by GitHub
parent a44dcbe513
commit 30567c28e8
2 changed files with 32 additions and 0 deletions

View File

@@ -70,6 +70,10 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
requires_backends(self, ["vision", "timm"])
super().__init__(*args, **kwargs)
def post_init(self):
self.supports_gradient_checkpointing = self._timm_model_supports_gradient_checkpointing()
super().post_init()
@staticmethod
def _fix_state_dict_key_on_load(key) -> tuple[str, bool]:
"""
@@ -107,6 +111,24 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
if module.bias is not None:
module.bias.data.zero_()
def _timm_model_supports_gradient_checkpointing(self):
"""
Check if the timm model supports gradient checkpointing by checking if the `set_grad_checkpointing` method is available.
Some timm models will have the method but will raise an AssertionError when called so in this case we return False.
"""
if not hasattr(self.timm_model, "set_grad_checkpointing"):
return False
try:
self.timm_model.set_grad_checkpointing(enable=True)
self.timm_model.set_grad_checkpointing(enable=False)
return True
except Exception:
return False
def _set_gradient_checkpointing(self, enable: bool = True, *args, **kwargs):
self.timm_model.set_grad_checkpointing(enable)
class TimmWrapperModel(TimmWrapperPreTrainedModel):
"""

View File

@@ -170,6 +170,16 @@ class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
def test_model_is_small(self):
pass
def test_gradient_checkpointing(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = TimmWrapperModel._from_config(config)
self.assertTrue(model.supports_gradient_checkpointing)
def test_gradient_checkpointing_on_non_supported_model(self):
config = TimmWrapperConfig.from_pretrained("timm/hrnet_w18.ms_aug_in1k")
model = TimmWrapperModel._from_config(config)
self.assertFalse(model.supports_gradient_checkpointing)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()