[core/ gradient_checkpointing] Refactor GC - part 2 (#27073)

* fix

* more fixes

* fix other models

* fix long t5

* use `gradient_checkpointing_func` instead

* fix copies

* set `gradient_checkpointing_func` as a private attribute and retrieve previous behaviour

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* replace it with `is_gradient_checkpointing_set`

* remove default

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2023-10-27 16:15:22 +02:00
committed by GitHub
parent 5be1fb6d1f
commit ffff9e70ab
186 changed files with 242 additions and 1145 deletions

View File

@@ -544,7 +544,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@@ -675,11 +675,6 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, {{cookiecutter.camelcase_modelname}}Encoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
@@ -2021,11 +2016,6 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ({{cookiecutter.camelcase_modelname}}Decoder, {{cookiecutter.camelcase_modelname}}Encoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic
@@ -2310,7 +2300,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
@@ -2543,7 +2533,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,