[core/ gradient_checkpointing] Refactor GC - part 2 (#27073)
* fix * more fixes * fix other models * fix long t5 * use `gradient_checkpointing_func` instead * fix copies * set `gradient_checkpointing_func` as a private attribute and retrieve previous behaviour * Update src/transformers/modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * replace it with `is_gradient_checkpointing_set` * remove default * Update src/transformers/modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fixup --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -1042,7 +1042,7 @@ class SamVisionEncoder(nn.Module):
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user