[core] Refactor of gradient_checkpointing (#27020)

* v1

* fix

* remove `create_custom_forward`

* fixup

* fixup

* add test and fix all failing GC tests

* remove all remaining `create_custom_forward` methods

* fix idefics bug

* fixup

* replace with `__call__`

* add comment

* quality
This commit is contained in:
Younes Belkada
2023-10-25 12:16:15 +02:00
committed by GitHub
parent 9286f0ac39
commit 06e782da4e
188 changed files with 1276 additions and 2296 deletions

View File

@@ -1042,15 +1042,8 @@ class SamVisionEncoder(nn.Module):
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                 )
             else: