[core] Refactor of gradient_checkpointing (#27020)
* v1 * fix * remove `create_custom_forward` * fixup * fixup * add test and fix all failing GC tests * remove all remaining `create_custom_forward` methods * fix idefics bug * fixup * replace with `__call__` * add comment * quality
This commit is contained in:
@@ -349,10 +349,24 @@ class ModelTesterMixin:
|
||||
model.gradient_checkpointing_enable()
|
||||
self.assertTrue(model.is_gradient_checkpointing)
|
||||
|
||||
# Loop over all modules and check that relevant modules have gradient_checkpointing set to True
|
||||
for n, m in model.named_modules():
|
||||
if hasattr(m, "gradient_checkpointing"):
|
||||
self.assertTrue(
|
||||
m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to True"
|
||||
)
|
||||
|
||||
# check disable works
|
||||
model.gradient_checkpointing_disable()
|
||||
self.assertFalse(model.is_gradient_checkpointing)
|
||||
|
||||
# Loop over all modules and check that relevant modules have gradient_checkpointing set to False
|
||||
for n, m in model.named_modules():
|
||||
if hasattr(m, "gradient_checkpointing"):
|
||||
self.assertFalse(
|
||||
m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False"
|
||||
)
|
||||
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if config.__class__ not in MODEL_MAPPING:
|
||||
@@ -569,6 +583,13 @@ class ModelTesterMixin:
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
|
||||
model.gradient_checkpointing_disable()
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
|
||||
model.train()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
|
||||
def test_attention_outputs(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model does not output attentions")
|
||||
|
||||
Reference in New Issue
Block a user