[core] Refactor of gradient_checkpointing (#27020)
* v1 * fix * remove `create_custom_forward` * fixup * fixup * add test and fix all failing GC tests * remove all remaining `create_custom_forward` methods * fix idefics bug * fixup * replace with `__call__` * add comment * quality
This commit is contained in:
@@ -1042,15 +1042,8 @@ class SamVisionEncoder(nn.Module):
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(layer_module),
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user