Make gradient_checkpointing a training argument (#13657)

* Make gradient_checkpointing a training argument

* Update src/transformers/modeling_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/configuration_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Fix tests

* Style

* document Gradient Checkpointing as a performance feature

* Small rename

* PoC for not using the config

* Adapt BC to new PoC

* Forgot to save

* Rollout changes to all other models

* Fix typo

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
Sylvain Gugger
2021-09-22 07:51:38 -04:00
committed by GitHub
parent 75f6641eaf
commit 27d4639779
96 changed files with 531 additions and 309 deletions

View File

@@ -72,8 +72,6 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if ``config.is_decoder=True``.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
{% else -%}
vocab_size (:obj:`int`, `optional`, defaults to 50265):
Vocabulary size of the {{cookiecutter.modelname}} model. Defines the number of different tokens that can be represented by the
@@ -186,7 +184,6 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
decoder_start_token_id=2,
classifier_dropout=0.0,
scale_embedding=False,
gradient_checkpointing=False,
{% endif -%}
pad_token_id=1,
bos_token_id=0,
@@ -225,7 +222,6 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
{% endif -%}

View File

@@ -513,6 +513,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
super().__init__()
self.config = config
self.layer = nn.ModuleList([{{cookiecutter.camelcase_modelname}}Layer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
@@ -539,12 +540,11 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
@@ -664,6 +664,7 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
config_class = {{cookiecutter.camelcase_modelname}}Config
load_tf_weights = load_tf_weights_in_{{cookiecutter.lowercase_modelname}}
base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
@@ -682,6 +683,10 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, {{cookiecutter.camelcase_modelname}}Encoder):
module.gradient_checkpointing = value
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
@@ -2006,6 +2011,7 @@ class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
config_class = {{cookiecutter.camelcase_modelname}}Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
def _init_weights(self, module):
std = self.config.init_std
@@ -2017,16 +2023,10 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
dummy_inputs = {
"attention_mask": input_ids.ne(pad_token),
"input_ids": input_ids,
}
return dummy_inputs
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, ({{cookiecutter.camelcase_modelname}}Decoder, {{cookiecutter.camelcase_modelname}}Encoder)):
module.gradient_checkpointing = value
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
@@ -2213,6 +2213,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.init_weights()
self.gradient_checkpointing = False
def forward(
self,
@@ -2309,7 +2310,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
@@ -2376,6 +2377,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.init_weights()
self.gradient_checkpointing = False
def get_input_embeddings(self):
return self.embed_tokens
@@ -2545,10 +2547,10 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning("`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`...")
logger.warning("`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`...")
use_cache = False
def create_custom_forward(module):