Make gradient_checkpointing a training argument (#13657)
* Make gradient_checkpointing a training argument * Update src/transformers/modeling_utils.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Update src/transformers/configuration_utils.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Fix tests * Style * document Gradient Checkpointing as a performance feature * Small rename * PoC for not using the config * Adapt BC to new PoC * Forgot to save * Rollout changes to all other models * Fix typo Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
@@ -72,8 +72,6 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if ``config.is_decoder=True``.
|
||||
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
|
||||
{% else -%}
|
||||
vocab_size (:obj:`int`, `optional`, defaults to 50265):
|
||||
Vocabulary size of the {{cookiecutter.modelname}} model. Defines the number of different tokens that can be represented by the
|
||||
@@ -186,7 +184,6 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
|
||||
decoder_start_token_id=2,
|
||||
classifier_dropout=0.0,
|
||||
scale_embedding=False,
|
||||
gradient_checkpointing=False,
|
||||
{% endif -%}
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
@@ -225,7 +222,6 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
|
||||
self.classifier_dropout = classifier_dropout
|
||||
self.use_cache = use_cache
|
||||
self.num_hidden_layers = encoder_layers
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
|
||||
{% endif -%}
|
||||
|
||||
@@ -513,6 +513,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([{{cookiecutter.camelcase_modelname}}Layer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -539,12 +540,11 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
if use_cache:
|
||||
logger.warning(
|
||||
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
|
||||
"`use_cache=False`..."
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
@@ -664,6 +664,7 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
|
||||
config_class = {{cookiecutter.camelcase_modelname}}Config
|
||||
load_tf_weights = load_tf_weights_in_{{cookiecutter.lowercase_modelname}}
|
||||
base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
|
||||
supports_gradient_checkpointing = True
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
@@ -682,6 +683,10 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, {{cookiecutter.camelcase_modelname}}Encoder):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
|
||||
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
|
||||
@@ -2006,6 +2011,7 @@ class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
|
||||
class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
|
||||
config_class = {{cookiecutter.camelcase_modelname}}Config
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
@@ -2017,16 +2023,10 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
pad_token = self.config.pad_token_id
|
||||
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
|
||||
dummy_inputs = {
|
||||
"attention_mask": input_ids.ne(pad_token),
|
||||
"input_ids": input_ids,
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, ({{cookiecutter.camelcase_modelname}}Decoder, {{cookiecutter.camelcase_modelname}}Encoder)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
|
||||
@@ -2213,6 +2213,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
|
||||
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
||||
|
||||
self.init_weights()
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -2309,7 +2310,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
|
||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
@@ -2376,6 +2377,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
||||
self.layernorm_embedding = nn.LayerNorm(config.d_model)
|
||||
|
||||
self.init_weights()
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
@@ -2545,10 +2547,10 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
if use_cache:
|
||||
logger.warning("`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`...")
|
||||
logger.warning("`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`...")
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
Reference in New Issue
Block a user