[timm_wrapper] add support for gradient checkpointing (#39287)
* feat: add support for gradient checkpointing in TimmWrapperModel and TimmWrapperForImageClassification * ruff fix * refactor + add test for not supported model * ruff * Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> --------- Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
@@ -70,6 +70,10 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
|
||||
requires_backends(self, ["vision", "timm"])
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def post_init(self):
|
||||
self.supports_gradient_checkpointing = self._timm_model_supports_gradient_checkpointing()
|
||||
super().post_init()
|
||||
|
||||
@staticmethod
|
||||
def _fix_state_dict_key_on_load(key) -> tuple[str, bool]:
|
||||
"""
|
||||
@@ -107,6 +111,24 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def _timm_model_supports_gradient_checkpointing(self):
|
||||
"""
|
||||
Check if the timm model supports gradient checkpointing by checking if the `set_grad_checkpointing` method is available.
|
||||
Some timm models will have the method but will raise an AssertionError when called so in this case we return False.
|
||||
"""
|
||||
if not hasattr(self.timm_model, "set_grad_checkpointing"):
|
||||
return False
|
||||
|
||||
try:
|
||||
self.timm_model.set_grad_checkpointing(enable=True)
|
||||
self.timm_model.set_grad_checkpointing(enable=False)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _set_gradient_checkpointing(self, enable: bool = True, *args, **kwargs):
|
||||
self.timm_model.set_grad_checkpointing(enable)
|
||||
|
||||
|
||||
class TimmWrapperModel(TimmWrapperPreTrainedModel):
|
||||
"""
|
||||
|
||||
@@ -170,6 +170,16 @@ class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
def test_model_is_small(self):
|
||||
pass
|
||||
|
||||
def test_gradient_checkpointing(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = TimmWrapperModel._from_config(config)
|
||||
self.assertTrue(model.supports_gradient_checkpointing)
|
||||
|
||||
def test_gradient_checkpointing_on_non_supported_model(self):
|
||||
config = TimmWrapperConfig.from_pretrained("timm/hrnet_w18.ms_aug_in1k")
|
||||
model = TimmWrapperModel._from_config(config)
|
||||
self.assertFalse(model.supports_gradient_checkpointing)
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user