[core/ gradient_checkpointing] Refactor GC - part 2 (#27073)
* fix * more fixes * fix other models * fix long t5 * use `gradient_checkpointing_func` instead * fix copies * set `gradient_checkpointing_func` as a private attribute and retrieve previous behaviour * Update src/transformers/modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * replace it with `is_gradient_checkpointing_set` * remove default * Update src/transformers/modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fixup --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -544,7 +544,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
@@ -675,11 +675,6 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, {{cookiecutter.camelcase_modelname}}Encoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
|
||||
@@ -2021,11 +2016,6 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, ({{cookiecutter.camelcase_modelname}}Decoder, {{cookiecutter.camelcase_modelname}}Encoder)):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic
|
||||
@@ -2310,7 +2300,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
@@ -2543,7 +2533,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
|
||||
Reference in New Issue
Block a user