[timm_wrapper] add support for gradient checkpointing (#39287)
* feat: add support for gradient checkpointing in TimmWrapperModel and TimmWrapperForImageClassification * ruff fix * refactor + add test for not supported model * ruff * Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> --------- Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
@@ -170,6 +170,16 @@ class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
def test_model_is_small(self):
|
||||
pass
|
||||
|
||||
def test_gradient_checkpointing(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = TimmWrapperModel._from_config(config)
|
||||
self.assertTrue(model.supports_gradient_checkpointing)
|
||||
|
||||
def test_gradient_checkpointing_on_non_supported_model(self):
|
||||
config = TimmWrapperConfig.from_pretrained("timm/hrnet_w18.ms_aug_in1k")
|
||||
model = TimmWrapperModel._from_config(config)
|
||||
self.assertFalse(model.supports_gradient_checkpointing)
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user