[core/ gradient_checkpointing] Refactor GC - part 2 (#27073)

* fix

* more fixes

* fix other models

* fix long t5

* use `gradient_checkpointing_func` instead

* fix copies

* set `gradient_checkpointing_func` as a private attribute and retrieve previous behaviour

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* replace it with `is_gradient_checkpointing_set`

* remove default

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2023-10-27 16:15:22 +02:00
committed by GitHub
parent 5be1fb6d1f
commit ffff9e70ab
186 changed files with 242 additions and 1145 deletions

View File

@@ -1042,7 +1042,7 @@ class SamVisionEncoder(nn.Module):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
)