[timm_wrapper] add support for gradient checkpointing (#39287)

* feat: add support for gradient checkpointing in TimmWrapperModel and TimmWrapperForImageClassification

* ruff fix

* refactor + add test for not supported model

* ruff

* Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

---------

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
Dominik Baran
2025-07-22 13:07:52 +02:00
committed by GitHub
parent a44dcbe513
commit 30567c28e8
2 changed files with 32 additions and 0 deletions

View File

@@ -170,6 +170,16 @@ class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
def test_model_is_small(self):
pass
def test_gradient_checkpointing(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = TimmWrapperModel._from_config(config)
self.assertTrue(model.supports_gradient_checkpointing)
def test_gradient_checkpointing_on_non_supported_model(self):
config = TimmWrapperConfig.from_pretrained("timm/hrnet_w18.ms_aug_in1k")
model = TimmWrapperModel._from_config(config)
self.assertFalse(model.supports_gradient_checkpointing)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()