[core] Refactor of gradient_checkpointing (#27020)

* v1

* fix

* remove `create_custom_forward`

* fixup

* fixup

* add test and fix all failing gradient checkpointing (GC) tests

* remove all remaining `create_custom_forward` methods

* fix idefics bug

* fixup

* replace with `__call__`

* add comment

* quality
This commit is contained in:
Younes Belkada
2023-10-25 12:16:15 +02:00
committed by GitHub
parent 9286f0ac39
commit 06e782da4e
188 changed files with 1276 additions and 2296 deletions

View File

@@ -349,10 +349,24 @@ class ModelTesterMixin:
model.gradient_checkpointing_enable()
self.assertTrue(model.is_gradient_checkpointing)
# Loop over all modules and check that relevant modules have gradient_checkpointing set to True
for n, m in model.named_modules():
if hasattr(m, "gradient_checkpointing"):
self.assertTrue(
m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to True"
)
# check disable works
model.gradient_checkpointing_disable()
self.assertFalse(model.is_gradient_checkpointing)
# Loop over all modules and check that relevant modules have gradient_checkpointing set to False
for n, m in model.named_modules():
if hasattr(m, "gradient_checkpointing"):
self.assertFalse(
m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False"
)
def test_save_load_fast_init_from_base(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.__class__ not in MODEL_MAPPING:
@@ -569,6 +583,13 @@ class ModelTesterMixin:
loss = model(**inputs).loss
loss.backward()
model.gradient_checkpointing_disable()
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
def test_attention_outputs(self):
if not self.has_attentions:
self.skipTest(reason="Model does not output attentions")