diff --git a/docs/source/model_doc/led.rst b/docs/source/model_doc/led.rst index 2e05163d37..1eaa9e325f 100644 --- a/docs/source/model_doc/led.rst +++ b/docs/source/model_doc/led.rst @@ -46,8 +46,8 @@ Tips: - LED makes use of *global attention* by means of the ``global_attention_mask`` (see :class:`~transformers.LongformerModel`). For summarization, it is advised to put *global attention* only on the first ```` token. For question answering, it is advised to put *global attention* on all tokens of the question. -- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by setting - ``config.gradient_checkpointing = True``. +- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by executing + ``model.gradient_checkpointing_enable()``. - A notebook showing how to evaluate LED, can be accessed `here `__. - A notebook showing how to fine-tune LED, can be accessed `here diff --git a/docs/source/performance.md b/docs/source/performance.md index 4f479d8575..c3239f3b0c 100644 --- a/docs/source/performance.md +++ b/docs/source/performance.md @@ -53,6 +53,7 @@ Software: - Tensor Parallelism - Low-memory Optimizers - fp16/bf16 (smaller data) +- Gradient checkpointing @@ -226,6 +227,21 @@ pytorch `autocast` which performs AMP include a caching feature, which speed thi Autocast maintains a cache of the FP16 casts of model params (leaves). This helps streamline parameter reuse: if the same FP32 param is used in several different FP16list ops, like several matmuls, instead of re-casting the param to FP16 on entering each matmul, the cast will occur on the first matmul, the casted FP16 copy will be cached, and for all later matmuls the FP16 copy will be reused. The cache is maintained only within a particular outermost autocast context. When you exit the autocast context the cache is dropped. For recommended usage, in which autocast wraps the forward pass, and then you exit the context before calling backward(), this means the cache only lasts the duration of the forward pass each iteration, and will be rebuilt next iteration. (The cache of FP16-casted copies MUST be rebuilt each iteration. The FP32 params get updated by the optimizer, so the FP16 copies must be recreated, otherwise the FP16 values will be stale.) + +### Gradient Checkpointing + +One way to use significantly less GPU memory is to enabled "Gradient Checkpointing" (also known as "activation checkpointing"). When enabled, a lot of memory can be freed at the cost of small decrease in the training speed due to recomputing parts of the graph during back-propagation. + +This technique was first shared in the paper: [Training Deep Nets with Sublinear Memory Cost](https://arxiv.org/abs/1604.06174). The paper will also give you the exact details on the savings, but it's in the ballpark of `O(sqrt(n))`, where `n` is the number of feed-forward layers. + +To activate this feature in 🤗 Transformers for models that support it, use: + +```python +model.gradient_checkpointing_enable() +``` +or add `--gradient_checkpointing` to the Trainer arguments. + + ### Batch sizes One gets the most efficient performance when batch sizes and input/output neuron counts are divisible by a certain number, which typically starts at 8, but can be much higher as well. That number varies a lot depending on the specific hardware being used and the dtype of the model. diff --git a/examples/pytorch/language-modeling/README.md b/examples/pytorch/language-modeling/README.md index 23989d7ed1..c768f5ec31 100644 --- a/examples/pytorch/language-modeling/README.md +++ b/examples/pytorch/language-modeling/README.md @@ -174,8 +174,3 @@ python run_clm.py --model_type gpt2 --tokenizer_name gpt2 \ --config_overrides=" ``` This feature is only available in `run_clm.py`, `run_plm.py` and `run_mlm.py`. - -This feature can also be used to activate gradient checkpointing by passing: -``` ---config_overrides "gradient_checkpointing=true,use_cache=False" -``` diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 45683ac801..bc3ecf77ba 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -19,6 +19,7 @@ import copy import json import os +import warnings from typing import Any, Dict, Tuple, Union from . import __version__ @@ -330,6 +331,14 @@ class PretrainedConfig(PushToHubMixin): # Drop the transformers version info self.transformers_version = kwargs.pop("transformers_version", None) + # Deal with gradient checkpointing + if "gradient_checkpointing" in kwargs: + warnings.warn( + "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 " + "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the " + "`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`." + ) + # Additional attributes without default values for key, value in kwargs.items(): try: diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e82d0ad9e3..21a1b09f30 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -20,6 +20,7 @@ import re import warnings from contextlib import contextmanager from dataclasses import dataclass +from functools import partial from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import torch @@ -450,6 +451,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix _keys_to_ignore_on_save = None is_parallelizable = False + supports_gradient_checkpointing = False @property def dummy_inputs(self) -> Dict[str, torch.Tensor]: @@ -469,6 +471,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Save config and origin of the pretrained weights if given in model self.config = config self.name_or_path = config.name_or_path + if getattr(self.config, "gradient_checkpointing", False): + self.gradient_checkpointing_enable() + # Remove the attribute now that is has been consumed, so it's no saved in the config. + delattr(self.config, "gradient_checkpointing") @classmethod def _from_config(cls, config, **kwargs): @@ -932,6 +938,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix self.base_model._prune_heads(heads_to_prune) + def gradient_checkpointing_enable(self, flag: bool = True): + """ + Activates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if not self.supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + self.apply(partial(self._set_gradient_checkpointing, value=True)) + + def gradient_checkpointing_disable(self, flag: bool = True): + """ + Deactivates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if self.supports_gradient_checkpointing: + self.apply(partial(self._set_gradient_checkpointing, value=False)) + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/transformers/models/bart/configuration_bart.py b/src/transformers/models/bart/configuration_bart.py index e26afb2ab4..6efbe4ca51 100644 --- a/src/transformers/models/bart/configuration_bart.py +++ b/src/transformers/models/bart/configuration_bart.py @@ -82,8 +82,6 @@ class BartConfig(PretrainedConfig): decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`): Scale embeddings by diving by sqrt(d_model). use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): @@ -131,7 +129,6 @@ class BartConfig(PretrainedConfig): init_std=0.02, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, use_cache=True, num_labels=3, pad_token_id=1, @@ -161,7 +158,6 @@ class BartConfig(PretrainedConfig): self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index a466be30a6..134669cee4 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -471,6 +471,7 @@ class BartClassificationHead(nn.Module): class BartPretrainedModel(PreTrainedModel): config_class = BartConfig base_model_prefix = "model" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"] def _init_weights(self, module): @@ -484,6 +485,10 @@ class BartPretrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (BartDecoder, BartEncoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -687,6 +692,7 @@ class BartEncoder(BartPretrainedModel): self.layernorm_embedding = nn.LayerNorm(embed_dim) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -782,7 +788,7 @@ class BartEncoder(BartPretrainedModel): if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -849,6 +855,7 @@ class BartDecoder(BartPretrainedModel): self.layernorm_embedding = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -1020,12 +1027,11 @@ class BartDecoder(BartPretrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/beit/configuration_beit.py b/src/transformers/models/beit/configuration_beit.py index 08ecc60646..d31f83dd3a 100644 --- a/src/transformers/models/beit/configuration_beit.py +++ b/src/transformers/models/beit/configuration_beit.py @@ -57,8 +57,6 @@ class BeitConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. image_size (:obj:`int`, `optional`, defaults to :obj:`224`): The size (resolution) of each image. patch_size (:obj:`int`, `optional`, defaults to :obj:`16`): diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 236551d27c..1ad3fcd1e6 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -432,6 +432,7 @@ class BeitEncoder(nn.Module): for i in range(config.num_hidden_layers) ] ) + self.gradient_checkpointing = False def forward( self, @@ -450,7 +451,7 @@ class BeitEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -494,6 +495,7 @@ class BeitPreTrainedModel(PreTrainedModel): config_class = BeitConfig base_model_prefix = "beit" + supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" @@ -511,6 +513,10 @@ class BeitPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BeitEncoder): + module.gradient_checkpointing = value + BEIT_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ subclass. Use diff --git a/src/transformers/models/bert/configuration_bert.py b/src/transformers/models/bert/configuration_bert.py index 8359f0c3b7..861cdfbc8e 100644 --- a/src/transformers/models/bert/configuration_bert.py +++ b/src/transformers/models/bert/configuration_bert.py @@ -92,8 +92,6 @@ class BertConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`): Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`, :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on @@ -137,7 +135,6 @@ class BertConfig(PretrainedConfig): initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - gradient_checkpointing=False, position_embedding_type="absolute", use_cache=True, classifier_dropout=None, @@ -157,7 +154,6 @@ class BertConfig(PretrainedConfig): self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index f02d67a31a..ecb0d184a4 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -529,6 +529,7 @@ class BertEncoder(nn.Module): super().__init__() self.config = config self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -555,12 +556,11 @@ class BertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -714,6 +714,7 @@ class BertPreTrainedModel(PreTrainedModel): config_class = BertConfig load_tf_weights = load_tf_weights_in_bert base_model_prefix = "bert" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -732,6 +733,10 @@ class BertPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BertEncoder): + module.gradient_checkpointing = value + @dataclass class BertForPreTrainingOutput(ModelOutput): diff --git a/src/transformers/models/bert_generation/configuration_bert_generation.py b/src/transformers/models/bert_generation/configuration_bert_generation.py index 54659f4394..2284f873e7 100644 --- a/src/transformers/models/bert_generation/configuration_bert_generation.py +++ b/src/transformers/models/bert_generation/configuration_bert_generation.py @@ -52,8 +52,6 @@ class BertGenerationConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass. position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`): Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`, :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on @@ -96,7 +94,6 @@ class BertGenerationConfig(PretrainedConfig): pad_token_id=0, bos_token_id=2, eos_token_id=1, - gradient_checkpointing=False, position_embedding_type="absolute", use_cache=True, **kwargs @@ -114,6 +111,5 @@ class BertGenerationConfig(PretrainedConfig): self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing self.position_embedding_type = position_embedding_type self.use_cache = use_cache diff --git a/src/transformers/models/big_bird/configuration_big_bird.py b/src/transformers/models/big_bird/configuration_big_bird.py index e6fdfd1d14..85dd8de7dd 100644 --- a/src/transformers/models/big_bird/configuration_big_bird.py +++ b/src/transformers/models/big_bird/configuration_big_bird.py @@ -82,8 +82,6 @@ class BigBirdConfig(PretrainedConfig): num_random_blocks (:obj:`int`, `optional`, defaults to 3) Each query is going to attend these many number of random blocks. Useful only when :obj:`attention_type == "block_sparse"`. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. classifier_dropout (:obj:`float`, `optional`): The dropout ratio for the classification head. @@ -127,7 +125,6 @@ class BigBirdConfig(PretrainedConfig): rescale_embeddings=False, block_size=64, num_random_blocks=3, - gradient_checkpointing=False, classifier_dropout=None, **kwargs ): @@ -153,7 +150,6 @@ class BigBirdConfig(PretrainedConfig): self.layer_norm_eps = layer_norm_eps self.use_cache = use_cache self.is_encoder_decoder = is_encoder_decoder - self.gradient_checkpointing = gradient_checkpointing self.rescale_embeddings = rescale_embeddings self.attention_type = attention_type diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index f7d0d857bc..84a428591e 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1555,6 +1555,7 @@ class BigBirdEncoder(nn.Module): self.layer = nn.ModuleList( [BigBirdLayer(config, seed=layer_idx) for layer_idx in range(config.num_hidden_layers)] ) + self.gradient_checkpointing = False def set_attention_type(self, value: str): if value not in ["original_full", "block_sparse"]: @@ -1598,12 +1599,11 @@ class BigBirdEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -1756,6 +1756,7 @@ class BigBirdPreTrainedModel(PreTrainedModel): config_class = BigBirdConfig load_tf_weights = load_tf_weights_in_big_bird base_model_prefix = "bert" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -1774,6 +1775,10 @@ class BigBirdPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BigBirdEncoder): + module.gradient_checkpointing = value + BIG_BIRD_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. Use diff --git a/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py index 28211c9b16..297e2cede4 100644 --- a/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py @@ -94,8 +94,6 @@ class BigBirdPegasusConfig(PretrainedConfig): "block_sparse"`. scale_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`) Whether to rescale embeddings with (hidden_size ** 0.5). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: @@ -141,7 +139,6 @@ class BigBirdPegasusConfig(PretrainedConfig): decoder_start_token_id=2, classifier_dropout=0.0, scale_embedding=True, - gradient_checkpointing=False, pad_token_id=0, bos_token_id=2, eos_token_id=1, @@ -170,7 +167,6 @@ class BigBirdPegasusConfig(PretrainedConfig): self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True # extra config diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 2fd765eb5d..536cd784da 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1567,6 +1567,7 @@ class BigBirdPegasusClassificationHead(nn.Module): class BigBirdPegasusPreTrainedModel(PreTrainedModel): config_class = BigBirdPegasusConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -1579,6 +1580,10 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (BigBirdPegasusDecoder, BigBirdPegasusEncoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -1764,6 +1769,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): self.layernorm_embedding = nn.LayerNorm(embed_dim) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -1894,7 +1900,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -2054,6 +2060,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): self.layernorm_embedding = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -2225,12 +2232,11 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/blenderbot/configuration_blenderbot.py b/src/transformers/models/blenderbot/configuration_blenderbot.py index c2b272af03..13acbdf699 100644 --- a/src/transformers/models/blenderbot/configuration_blenderbot.py +++ b/src/transformers/models/blenderbot/configuration_blenderbot.py @@ -78,8 +78,6 @@ class BlenderbotConfig(PretrainedConfig): decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`): Scale embeddings by diving by sqrt(d_model). use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): @@ -128,7 +126,6 @@ class BlenderbotConfig(PretrainedConfig): decoder_start_token_id=1, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, pad_token_id=0, bos_token_id=1, eos_token_id=2, @@ -155,7 +152,6 @@ class BlenderbotConfig(PretrainedConfig): self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index e6bc6f6571..11e866594a 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -451,6 +451,7 @@ class BlenderbotDecoderLayer(nn.Module): class BlenderbotPreTrainedModel(PreTrainedModel): config_class = BlenderbotConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -463,6 +464,10 @@ class BlenderbotPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (BlenderbotDecoder, BlenderbotEncoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -644,6 +649,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel): self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -738,7 +744,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel): if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -808,6 +814,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -980,12 +987,11 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py b/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py index de8927a4ff..0f76e2e3ae 100644 --- a/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py @@ -78,8 +78,6 @@ class BlenderbotSmallConfig(PretrainedConfig): decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`): Scale embeddings by diving by sqrt(d_model). use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): @@ -128,7 +126,6 @@ class BlenderbotSmallConfig(PretrainedConfig): decoder_start_token_id=1, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, pad_token_id=0, bos_token_id=1, eos_token_id=2, @@ -154,7 +151,6 @@ class BlenderbotSmallConfig(PretrainedConfig): self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 81188488fe..a15c8276c3 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -449,6 +449,7 @@ class BlenderbotSmallDecoderLayer(nn.Module): class BlenderbotSmallPreTrainedModel(PreTrainedModel): config_class = BlenderbotSmallConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -461,6 +462,10 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (BlenderbotSmallDecoder, BlenderbotSmallEncoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -645,6 +650,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): self.layernorm_embedding = nn.LayerNorm(embed_dim) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -740,7 +746,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -808,6 +814,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): self.layernorm_embedding = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -981,12 +988,11 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/canine/configuration_canine.py b/src/transformers/models/canine/configuration_canine.py index 3feef5ac75..79be54a824 100644 --- a/src/transformers/models/canine/configuration_canine.py +++ b/src/transformers/models/canine/configuration_canine.py @@ -61,8 +61,6 @@ class CanineConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass. downsampling_rate (:obj:`int`, `optional`, defaults to 4): The rate at which to downsample the original character sequence length before applying the deep Transformer encoder. diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 18ca01031c..a13505d3a0 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -772,6 +772,7 @@ class CanineEncoder(nn.Module): for _ in range(config.num_hidden_layers) ] ) + self.gradient_checkpointing = False def forward( self, @@ -791,7 +792,7 @@ class CanineEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -895,6 +896,7 @@ class CaninePreTrainedModel(PreTrainedModel): config_class = CanineConfig load_tf_weights = load_tf_weights_in_canine base_model_prefix = "canine" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -913,6 +915,10 @@ class CaninePreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, CanineEncoder): + module.gradient_checkpointing = value + CANINE_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. Use diff --git a/src/transformers/models/clip/configuration_clip.py b/src/transformers/models/clip/configuration_clip.py index b824288711..0f8b6fa9a4 100644 --- a/src/transformers/models/clip/configuration_clip.py +++ b/src/transformers/models/clip/configuration_clip.py @@ -68,8 +68,6 @@ class CLIPTextConfig(PretrainedConfig): initializer_factor (:obj:`float`, `optional`, defaults to 1): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: @@ -103,7 +101,6 @@ class CLIPTextConfig(PretrainedConfig): pad_token_id=1, bos_token_id=0, eos_token_id=2, - gradient_checkpointing=False, **kwargs ): super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) @@ -120,7 +117,6 @@ class CLIPTextConfig(PretrainedConfig): self.initializer_range = initializer_range self.initializer_factor = initializer_factor self.attention_dropout = attention_dropout - self.gradient_checkpointing = gradient_checkpointing class CLIPVisionConfig(PretrainedConfig): @@ -161,8 +157,6 @@ class CLIPVisionConfig(PretrainedConfig): initializer_factor (:obj:`float`, `optional`, defaults to 1): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: @@ -194,7 +188,6 @@ class CLIPVisionConfig(PretrainedConfig): attention_dropout=0.0, initializer_range=0.02, initializer_factor=1.0, - gradient_checkpointing=False, **kwargs ): super().__init__(**kwargs) @@ -211,7 +204,6 @@ class CLIPVisionConfig(PretrainedConfig): self.attention_dropout = attention_dropout self.layer_norm_eps = layer_norm_eps self.hidden_act = hidden_act - self.gradient_checkpointing = gradient_checkpointing class CLIPConfig(PretrainedConfig): diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 8d723e05fc..4f3b280a1b 100755 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -338,6 +338,7 @@ class CLIPPreTrainedModel(PreTrainedModel): config_class = CLIPConfig base_model_prefix = "clip" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -383,6 +384,10 @@ class CLIPPreTrainedModel(PreTrainedModel): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, CLIPEncoder): + module.gradient_checkpointing = value + CLIP_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ subclass. Use @@ -499,6 +504,7 @@ class CLIPEncoder(nn.Module): super().__init__() self.config = config self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -551,7 +557,7 @@ class CLIPEncoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index fbd0cdfc5e..99d8ae5dd4 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -248,6 +248,7 @@ class ConvBertPreTrainedModel(PreTrainedModel): config_class = ConvBertConfig load_tf_weights = load_tf_weights_in_convbert base_model_prefix = "convbert" + supports_gradient_checkpointing = True authorized_missing_keys = [r"position_ids"] authorized_unexpected_keys = [r"convbert\.embeddings_project\.weight", r"convbert\.embeddings_project\.bias"] @@ -267,6 +268,10 @@ class ConvBertPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ConvBertEncoder): + module.gradient_checkpointing = value + class SeparableConv1D(nn.Module): """This class implements separable convolution, i.e. a depthwise and a pointwise layer""" @@ -603,6 +608,7 @@ class ConvBertEncoder(nn.Module): super().__init__() self.config = config self.layer = nn.ModuleList([ConvBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -624,7 +630,7 @@ class ConvBertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/deit/configuration_deit.py b/src/transformers/models/deit/configuration_deit.py index 0bbbff709b..98bbe1b01b 100644 --- a/src/transformers/models/deit/configuration_deit.py +++ b/src/transformers/models/deit/configuration_deit.py @@ -58,8 +58,6 @@ class DeiTConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. image_size (:obj:`int`, `optional`, defaults to :obj:`224`): The size (resolution) of each image. patch_size (:obj:`int`, `optional`, defaults to :obj:`16`): diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index b848376817..6ffa6afa3a 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -324,6 +324,7 @@ class DeiTEncoder(nn.Module): super().__init__() self.config = config self.layer = nn.ModuleList([DeiTLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -342,7 +343,7 @@ class DeiTEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -384,6 +385,7 @@ class DeiTPreTrainedModel(PreTrainedModel): config_class = DeiTConfig base_model_prefix = "deit" + supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" @@ -401,6 +403,10 @@ class DeiTPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, DeiTEncoder): + module.gradient_checkpointing = value + DEIT_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ subclass. Use diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 3061addada..af650e75e1 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -783,6 +783,7 @@ class DetrClassificationHead(nn.Module): class DetrPreTrainedModel(PreTrainedModel): config_class = DetrConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -807,6 +808,10 @@ class DetrPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, DetrDecoder): + module.gradient_checkpointing = value + DETR_START_DOCSTRING = r""" This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic @@ -997,6 +1002,7 @@ class DetrDecoder(DetrPreTrainedModel): self.layernorm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -1084,7 +1090,7 @@ class DetrDecoder(DetrPreTrainedModel): if self.training and (dropout_probability < self.layerdrop): continue - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/dpr/configuration_dpr.py b/src/transformers/models/dpr/configuration_dpr.py index 2773835f72..a9b5f96556 100644 --- a/src/transformers/models/dpr/configuration_dpr.py +++ b/src/transformers/models/dpr/configuration_dpr.py @@ -69,8 +69,6 @@ class DPRConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`): Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`, :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on @@ -99,7 +97,6 @@ class DPRConfig(PretrainedConfig): initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - gradient_checkpointing=False, position_embedding_type="absolute", projection_dim: int = 0, **kwargs @@ -118,6 +115,5 @@ class DPRConfig(PretrainedConfig): self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing self.projection_dim = projection_dim self.position_embedding_type = position_embedding_type diff --git a/src/transformers/models/dpr/modeling_dpr.py b/src/transformers/models/dpr/modeling_dpr.py index 37fa61b706..c1a3fa618d 100644 --- a/src/transformers/models/dpr/modeling_dpr.py +++ b/src/transformers/models/dpr/modeling_dpr.py @@ -30,7 +30,7 @@ from ...file_utils import ( from ...modeling_outputs import BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import logging -from ..bert.modeling_bert import BertModel +from ..bert.modeling_bert import BertEncoder, BertModel from .configuration_dpr import DPRConfig @@ -300,6 +300,10 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel): def init_weights(self): self.question_encoder.init_weights() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BertEncoder): + module.gradient_checkpointing = value + class DPRPretrainedReader(PreTrainedModel): """ @@ -317,6 +321,10 @@ class DPRPretrainedReader(PreTrainedModel): self.span_predictor.qa_classifier.apply(self.span_predictor.encoder.bert_model._init_weights) self.span_predictor.qa_outputs.apply(self.span_predictor.encoder.bert_model._init_weights) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BertEncoder): + module.gradient_checkpointing = value + ############### # Actual Models diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 867a7c0915..1f44b23522 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -527,6 +527,7 @@ class ElectraEncoder(nn.Module): super().__init__() self.config = config self.layer = nn.ModuleList([ElectraLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -553,12 +554,11 @@ class ElectraEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -663,6 +663,7 @@ class ElectraPreTrainedModel(PreTrainedModel): config_class = ElectraConfig load_tf_weights = load_tf_weights_in_electra base_model_prefix = "electra" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] _keys_to_ignore_on_load_unexpected = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"] @@ -683,6 +684,10 @@ class ElectraPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ElectraEncoder): + module.gradient_checkpointing = value + @dataclass class ElectraForPreTrainingOutput(ModelOutput): diff --git a/src/transformers/models/fnet/configuration_fnet.py b/src/transformers/models/fnet/configuration_fnet.py index 047190a3ed..a6922f8355 100644 --- a/src/transformers/models/fnet/configuration_fnet.py +++ b/src/transformers/models/fnet/configuration_fnet.py @@ -64,8 +64,6 @@ class FNetConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass. use_tpu_fourier_optimizations (:obj:`bool`, `optional`, defaults to :obj:`False`): Determines whether to use TPU optimized FFTs. If :obj:`True`, the model will favor axis-wise FFTs transforms. Set to :obj:`False` for GPU/CPU hardware, in which case n-dimensional FFTs are used. diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 2a1b7f5f2a..9340eb04f3 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -284,6 +284,7 @@ class FNetEncoder(nn.Module): super().__init__() self.config = config self.layer = nn.ModuleList([FNetLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward(self, hidden_states, output_hidden_states=False, return_dict=True): all_hidden_states = () if output_hidden_states else None @@ -292,7 +293,7 @@ class FNetEncoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -413,6 +414,7 @@ class FNetPreTrainedModel(PreTrainedModel): config_class = FNetConfig base_model_prefix = "fnet" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -432,6 +434,10 @@ class FNetPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, FNetEncoder): + module.gradient_checkpointing = value + @dataclass class FNetForPreTrainingOutput(ModelOutput): diff --git a/src/transformers/models/gpt2/configuration_gpt2.py b/src/transformers/models/gpt2/configuration_gpt2.py index f003023ca8..41120c94da 100644 --- a/src/transformers/models/gpt2/configuration_gpt2.py +++ b/src/transformers/models/gpt2/configuration_gpt2.py @@ -108,9 +108,7 @@ class GPT2Config(PretrainedConfig): The dropout ratio to be used after the projection and activation. scale_attn_weights (:obj:`bool`, `optional`, defaults to :obj:`True`): - Scale attention weights by dividing by sqrt(hidden_size). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass. + Scale attention weights by dividing by sqrt(hidden_size).. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). @@ -158,7 +156,6 @@ class GPT2Config(PretrainedConfig): summary_proj_to_labels=True, summary_first_dropout=0.1, scale_attn_weights=True, - gradient_checkpointing=False, use_cache=True, bos_token_id=50256, eos_token_id=50256, @@ -182,7 +179,6 @@ class GPT2Config(PretrainedConfig): self.summary_activation = summary_activation self.summary_first_dropout = summary_first_dropout self.summary_proj_to_labels = summary_proj_to_labels - self.gradient_checkpointing = gradient_checkpointing self.scale_attn_weights = scale_attn_weights self.use_cache = use_cache diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 43419e6615..d6fab7f7ff 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -374,6 +374,7 @@ class GPT2PreTrainedModel(PreTrainedModel): load_tf_weights = load_tf_weights_in_gpt2 base_model_prefix = "transformer" is_parallelizable = True + supports_gradient_checkpointing = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -394,6 +395,10 @@ class GPT2PreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GPT2Model): + module.gradient_checkpointing = value + @dataclass class GPT2DoubleHeadsModelOutput(ModelOutput): @@ -589,6 +594,7 @@ class GPT2Model(GPT2PreTrainedModel): # Model parallel self.model_parallel = False self.device_map = None + self.gradient_checkpointing = False @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): @@ -764,12 +770,11 @@ class GPT2Model(GPT2PreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/gpt_neo/configuration_gpt_neo.py b/src/transformers/models/gpt_neo/configuration_gpt_neo.py index e5b7e683d9..d5069fb017 100644 --- a/src/transformers/models/gpt_neo/configuration_gpt_neo.py +++ b/src/transformers/models/gpt_neo/configuration_gpt_neo.py @@ -79,8 +79,6 @@ class GPTNeoConfig(PretrainedConfig): use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if ``config.is_decoder=True``. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: @@ -120,7 +118,6 @@ class GPTNeoConfig(PretrainedConfig): summary_activation=None, summary_proj_to_labels=True, summary_first_dropout=0.1, - gradient_checkpointing=False, use_cache=True, bos_token_id=50256, eos_token_id=50256, @@ -144,7 +141,6 @@ class GPTNeoConfig(PretrainedConfig): self.summary_activation = summary_activation self.summary_first_dropout = summary_first_dropout self.summary_proj_to_labels = summary_proj_to_labels - self.gradient_checkpointing = gradient_checkpointing self.use_cache = use_cache self.bos_token_id = bos_token_id diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 353d3b0fb6..3fafd75ac2 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -361,6 +361,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel): config_class = GPTNeoConfig load_tf_weights = load_tf_weights_in_gpt_neo base_model_prefix = "transformer" + supports_gradient_checkpointing = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -381,6 +382,10 @@ class GPTNeoPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GPTNeoModel): + module.gradient_checkpointing = value + GPT_NEO_START_DOCSTRING = r""" @@ -482,6 +487,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel): self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.wte @@ -592,12 +598,11 @@ class GPTNeoModel(GPTNeoPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/gptj/configuration_gptj.py b/src/transformers/models/gptj/configuration_gptj.py index 93018fdcb6..61dfd4e663 100644 --- a/src/transformers/models/gptj/configuration_gptj.py +++ b/src/transformers/models/gptj/configuration_gptj.py @@ -68,8 +68,6 @@ class GPTJConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. scale_attn_weights (:obj:`bool`, `optional`, defaults to :obj:`True`): Scale attention weights by dividing by sqrt(hidden_size). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). @@ -111,7 +109,6 @@ class GPTJConfig(PretrainedConfig): layer_norm_epsilon=1e-5, initializer_range=0.02, scale_attn_weights=True, - gradient_checkpointing=False, use_cache=True, bos_token_id=50256, eos_token_id=50256, @@ -131,7 +128,6 @@ class GPTJConfig(PretrainedConfig): self.attn_pdrop = attn_pdrop self.layer_norm_epsilon = layer_norm_epsilon self.initializer_range = initializer_range - self.gradient_checkpointing = gradient_checkpointing self.scale_attn_weights = scale_attn_weights self.use_cache = use_cache diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 2d7781a275..a23da08347 100755 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -303,6 +303,7 @@ class GPTJPreTrainedModel(PreTrainedModel): config_class = GPTJConfig base_model_prefix = "transformer" is_parallelizable = True + supports_gradient_checkpointing = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -323,6 +324,10 @@ class GPTJPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GPTJModel): + module.gradient_checkpointing = value + GPTJ_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. Use @@ -445,6 +450,7 @@ class GPTJModel(GPTJPreTrainedModel): # Model parallel self.model_parallel = False self.device_map = None + self.gradient_checkpointing = False @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): @@ -598,12 +604,11 @@ class GPTJModel(GPTJPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/hubert/configuration_hubert.py b/src/transformers/models/hubert/configuration_hubert.py index 633807684f..624211431c 100644 --- a/src/transformers/models/hubert/configuration_hubert.py +++ b/src/transformers/models/hubert/configuration_hubert.py @@ -120,8 +120,6 @@ class HubertConfig(PretrainedConfig): instance of :class:`~transformers.HubertForSequenceClassification`. classifier_proj_size (:obj:`int`, `optional`, defaults to 256): Dimensionality of the projection before token mean-pooling for classification. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: @@ -172,7 +170,6 @@ class HubertConfig(PretrainedConfig): ctc_zero_infinity=False, use_weighted_layer_sum=False, classifier_proj_size=256, - gradient_checkpointing=False, pad_token_id=0, bos_token_id=1, eos_token_id=2, @@ -203,7 +200,6 @@ class HubertConfig(PretrainedConfig): self.initializer_range = initializer_range self.vocab_size = vocab_size self.do_stable_layer_norm = do_stable_layer_norm - self.gradient_checkpointing = gradient_checkpointing self.use_weighted_layer_sum = use_weighted_layer_sum self.classifier_proj_size = classifier_proj_size diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 95d5c91f5a..6575f4932b 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -525,6 +525,7 @@ class HubertEncoder(nn.Module): self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -564,7 +565,7 @@ class HubertEncoder(nn.Module): skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: # create gradient checkpointing function def create_custom_forward(module): def custom_forward(*inputs): @@ -612,6 +613,7 @@ class HubertEncoderStableLayerNorm(nn.Module): self.layers = nn.ModuleList( [HubertEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] ) + self.gradient_checkpointing = False def forward( self, @@ -651,7 +653,7 @@ class HubertEncoderStableLayerNorm(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: # create gradient checkpointing function def create_custom_forward(module): def custom_forward(*inputs): @@ -698,6 +700,7 @@ class HubertPreTrainedModel(PreTrainedModel): config_class = HubertConfig base_model_prefix = "hubert" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -725,6 +728,10 @@ class HubertPreTrainedModel(PreTrainedModel): if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (HubertEncoder, HubertEncoderStableLayerNorm)): + module.gradient_checkpointing = value + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ Computes the output length of the convolutional layers diff --git a/src/transformers/models/ibert/modeling_ibert.py b/src/transformers/models/ibert/modeling_ibert.py index 61c775d9ff..d4f74ff47e 100644 --- a/src/transformers/models/ibert/modeling_ibert.py +++ b/src/transformers/models/ibert/modeling_ibert.py @@ -579,17 +579,13 @@ class IBertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: - raise NotImplementedError("gradient checkpointing is not currently supported") - - else: - layer_outputs = layer_module( - hidden_states, - hidden_states_scaling_factor, - attention_mask, - layer_head_mask, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + hidden_states_scaling_factor, + attention_mask, + layer_head_mask, + output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/layoutlm/configuration_layoutlm.py b/src/transformers/models/layoutlm/configuration_layoutlm.py index 4c23cde9a5..61a6ce264d 100644 --- a/src/transformers/models/layoutlm/configuration_layoutlm.py +++ b/src/transformers/models/layoutlm/configuration_layoutlm.py @@ -71,8 +71,6 @@ class LayoutLMConfig(BertConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. max_2d_position_embeddings (:obj:`int`, `optional`, defaults to 1024): The maximum value that the 2D position embedding might ever used. Typically set this to something large just in case (e.g., 1024). @@ -108,7 +106,6 @@ class LayoutLMConfig(BertConfig): initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - gradient_checkpointing=False, max_2d_position_embeddings=1024, **kwargs ): @@ -126,7 +123,6 @@ class LayoutLMConfig(BertConfig): initializer_range=initializer_range, layer_norm_eps=layer_norm_eps, pad_token_id=pad_token_id, - gradient_checkpointing=gradient_checkpointing, **kwargs, ) self.max_2d_position_embeddings = max_2d_position_embeddings diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 3e7dfe8560..b47d2793d1 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -442,6 +442,7 @@ class LayoutLMEncoder(nn.Module): super().__init__() self.config = config self.layer = nn.ModuleList([LayoutLMLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -468,12 +469,11 @@ class LayoutLMEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -609,6 +609,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel): config_class = LayoutLMConfig pretrained_model_archive_map = LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST base_model_prefix = "layoutlm" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -627,6 +628,10 @@ class LayoutLMPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LayoutLMEncoder): + module.gradient_checkpointing = value + LAYOUTLM_START_DOCSTRING = r""" The LayoutLM model was proposed in `LayoutLM: Pre-training of Text and Layout for Document Image Understanding diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index e42d77bab2..6c42ce1ccc 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -378,6 +378,8 @@ class LayoutLMv2Encoder(nn.Module): self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False) self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False) + self.gradient_checkpointing = False + def _calculate_1d_position_embeddings(self, hidden_states, position_ids): rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1) rel_pos = relative_position_bucket( @@ -443,7 +445,7 @@ class LayoutLMv2Encoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -502,6 +504,7 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel): config_class = LayoutLMv2Config pretrained_model_archive_map = LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST base_model_prefix = "layoutlmv2" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -520,6 +523,10 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LayoutLMv2Encoder): + module.gradient_checkpointing = value + def my_convert_sync_batchnorm(module, process_group=None): # same as `nn.modules.SyncBatchNorm.convert_sync_batchnorm` but allowing converting from `detectron2.layers.FrozenBatchNorm2d` diff --git a/src/transformers/models/led/configuration_led.py b/src/transformers/models/led/configuration_led.py index 5992d275ed..e30c3e04c4 100644 --- a/src/transformers/models/led/configuration_led.py +++ b/src/transformers/models/led/configuration_led.py @@ -82,8 +82,6 @@ class LEDConfig(PretrainedConfig): https://arxiv.org/abs/1909.11556>`__ for more details. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models) - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: @@ -132,7 +130,6 @@ class LEDConfig(PretrainedConfig): pad_token_id=1, bos_token_id=0, eos_token_id=2, - gradient_checkpointing=False, attention_window: Union[List[int], int] = 512, **kwargs ): @@ -157,7 +154,6 @@ class LEDConfig(PretrainedConfig): self.use_cache = use_cache self.num_hidden_layers = encoder_layers self.attention_window = attention_window - self.gradient_checkpointing = gradient_checkpointing super().__init__( pad_token_id=pad_token_id, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index c1c5af6d1e..926da161a9 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1077,6 +1077,7 @@ class LEDClassificationHead(nn.Module): class LEDPreTrainedModel(PreTrainedModel): config_class = LEDConfig base_model_prefix = "led" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -1089,6 +1090,10 @@ class LEDPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (LEDDecoder, LEDEncoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -1625,6 +1630,7 @@ class LEDEncoder(LEDPreTrainedModel): self.layernorm_embedding = nn.LayerNorm(embed_dim) self.init_weights() + self.gradient_checkpointing = False def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor): # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn) @@ -1809,7 +1815,7 @@ class LEDEncoder(LEDPreTrainedModel): if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -1894,6 +1900,7 @@ class LEDDecoder(LEDPreTrainedModel): self.layernorm_embedding = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -2061,12 +2068,11 @@ class LEDDecoder(LEDPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 6fbdfb12f5..3e327c5c68 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1231,6 +1231,7 @@ class LongformerEncoder(nn.Module): super().__init__() self.config = config self.layer = nn.ModuleList([LongformerLayer(config, layer_id=i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -1259,7 +1260,7 @@ class LongformerEncoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -1363,6 +1364,7 @@ class LongformerPreTrainedModel(PreTrainedModel): config_class = LongformerConfig base_model_prefix = "longformer" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -1381,6 +1383,10 @@ class LongformerPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LongformerEncoder): + module.gradient_checkpointing = value + LONGFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/luke/configuration_luke.py b/src/transformers/models/luke/configuration_luke.py index befd3e45e5..ba6dc49643 100644 --- a/src/transformers/models/luke/configuration_luke.py +++ b/src/transformers/models/luke/configuration_luke.py @@ -68,8 +68,6 @@ class LukeConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. use_entity_aware_attention (:obj:`bool`, defaults to :obj:`True`): Whether or not the model should use the entity-aware self-attention mechanism proposed in `LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention (Yamada et al.) @@ -106,7 +104,6 @@ class LukeConfig(PretrainedConfig): type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, - gradient_checkpointing=False, use_entity_aware_attention=True, pad_token_id=1, bos_token_id=0, @@ -130,5 +127,4 @@ class LukeConfig(PretrainedConfig): self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing self.use_entity_aware_attention = use_entity_aware_attention diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index b9004c1d49..97d1f1adfd 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -579,6 +579,7 @@ class LukeEncoder(nn.Module): super().__init__() self.config = config self.layer = nn.ModuleList([LukeLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -600,7 +601,7 @@ class LukeEncoder(nn.Module): all_entity_hidden_states = all_entity_hidden_states + (entity_hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -681,6 +682,7 @@ class LukePreTrainedModel(PreTrainedModel): config_class = LukeConfig base_model_prefix = "luke" + supports_gradient_checkpointing = True def _init_weights(self, module: nn.Module): """Initialize the weights""" @@ -699,6 +701,10 @@ class LukePreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LukeEncoder): + module.gradient_checkpointing = value + LUKE_START_DOCSTRING = r""" diff --git a/src/transformers/models/m2m_100/configuration_m2m_100.py b/src/transformers/models/m2m_100/configuration_m2m_100.py index 765bcb4cd1..a4a0df749c 100644 --- a/src/transformers/models/m2m_100/configuration_m2m_100.py +++ b/src/transformers/models/m2m_100/configuration_m2m_100.py @@ -79,8 +79,6 @@ class M2M100Config(PretrainedConfig): https://arxiv.org/abs/1909.11556>`__ for more details. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: @@ -121,7 +119,6 @@ class M2M100Config(PretrainedConfig): init_std=0.02, decoder_start_token_id=2, scale_embedding=True, - gradient_checkpointing=False, pad_token_id=1, bos_token_id=0, eos_token_id=2, @@ -145,7 +142,6 @@ class M2M100Config(PretrainedConfig): self.decoder_layerdrop = decoder_layerdrop self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index ce86fe2c77..9bb15c918a 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -520,6 +520,7 @@ class M2M100DecoderLayer(nn.Module): class M2M100PreTrainedModel(PreTrainedModel): config_class = M2M100Config base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -532,6 +533,10 @@ class M2M100PreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (M2M100Decoder, M2M100Encoder)): + module.gradient_checkpointing = value + M2M_100_START_DOCSTRING = r""" This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic @@ -693,6 +698,7 @@ class M2M100Encoder(M2M100PreTrainedModel): self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -787,7 +793,7 @@ class M2M100Encoder(M2M100PreTrainedModel): if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -857,6 +863,7 @@ class M2M100Decoder(M2M100PreTrainedModel): self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -1013,12 +1020,11 @@ class M2M100Decoder(M2M100PreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/marian/configuration_marian.py b/src/transformers/models/marian/configuration_marian.py index 1b974badfa..825c7d707a 100644 --- a/src/transformers/models/marian/configuration_marian.py +++ b/src/transformers/models/marian/configuration_marian.py @@ -78,8 +78,6 @@ class MarianConfig(PretrainedConfig): decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`): Scale embeddings by diving by sqrt(d_model). use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): @@ -128,7 +126,6 @@ class MarianConfig(PretrainedConfig): decoder_start_token_id=58100, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, pad_token_id=58100, eos_token_id=0, forced_eos_token_id=0, @@ -153,7 +150,6 @@ class MarianConfig(PretrainedConfig): self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( pad_token_id=pad_token_id, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index e2feb549b7..a2df637350 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -466,6 +466,7 @@ class MarianDecoderLayer(nn.Module): class MarianPreTrainedModel(PreTrainedModel): config_class = MarianConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -480,6 +481,10 @@ class MarianPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (MarianDecoder, MarianEncoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -656,6 +661,7 @@ class MarianEncoder(MarianPreTrainedModel): ) self.layers = nn.ModuleList([MarianEncoderLayer(config) for _ in range(config.encoder_layers)]) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -750,7 +756,7 @@ class MarianEncoder(MarianPreTrainedModel): if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -816,6 +822,7 @@ class MarianDecoder(MarianPreTrainedModel): ) self.layers = nn.ModuleList([MarianDecoderLayer(config) for _ in range(config.decoder_layers)]) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -987,12 +994,11 @@ class MarianDecoder(MarianPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/mbart/configuration_mbart.py b/src/transformers/models/mbart/configuration_mbart.py index 05857241b4..d1eb27c0e8 100644 --- a/src/transformers/models/mbart/configuration_mbart.py +++ b/src/transformers/models/mbart/configuration_mbart.py @@ -82,8 +82,6 @@ class MBartConfig(PretrainedConfig): decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`): Scale embeddings by diving by sqrt(d_model). use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): @@ -131,7 +129,6 @@ class MBartConfig(PretrainedConfig): init_std=0.02, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, pad_token_id=1, bos_token_id=0, eos_token_id=2, @@ -157,7 +154,6 @@ class MBartConfig(PretrainedConfig): self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( pad_token_id=pad_token_id, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 0412eccaaa..0ebb5a1a8f 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -479,6 +479,7 @@ class MBartClassificationHead(nn.Module): class MBartPreTrainedModel(PreTrainedModel): config_class = MBartConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -491,6 +492,10 @@ class MBartPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (MBartDecoder, MBartDecoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -685,6 +690,7 @@ class MBartEncoder(MBartPreTrainedModel): self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -780,7 +786,7 @@ class MBartEncoder(MBartPreTrainedModel): if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -850,6 +856,7 @@ class MBartDecoder(MBartPreTrainedModel): self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -1022,12 +1029,11 @@ class MBartDecoder(MBartPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/megatron_bert/configuration_megatron_bert.py b/src/transformers/models/megatron_bert/configuration_megatron_bert.py index 19171e70da..d6e32cd496 100644 --- a/src/transformers/models/megatron_bert/configuration_megatron_bert.py +++ b/src/transformers/models/megatron_bert/configuration_megatron_bert.py @@ -65,8 +65,6 @@ class MegatronBertConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`): Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`, :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on @@ -108,7 +106,6 @@ class MegatronBertConfig(PretrainedConfig): initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - gradient_checkpointing=False, position_embedding_type="absolute", use_cache=True, **kwargs @@ -127,6 +124,5 @@ class MegatronBertConfig(PretrainedConfig): self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing self.position_embedding_type = position_embedding_type self.use_cache = use_cache diff --git a/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py b/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py index 3d7f03dcbb..1d33ef91e6 100644 --- a/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py +++ b/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py @@ -180,7 +180,6 @@ def convert_megatron_checkpoint(args, input_state_dict): "type_vocab_size": 2, "initializer_range": 0.2, "layer_norm_eps": 1e-12, - "gradient_checkpointing": False, "position_embedding_type": "absolute", "use_cache": False, } diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 3c49ea88b8..80337b2dab 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -508,6 +508,7 @@ class MegatronBertEncoder(nn.Module): # The final layer norm. We removed the 1st LN, moved LN to each hidden layer and this one # is simply the final LN (Transformer's BERT has it attached to each hidden layer). self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.gradient_checkpointing = False def forward( self, @@ -534,12 +535,11 @@ class MegatronBertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warn( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -705,6 +705,7 @@ class MegatronBertPreTrainedModel(PreTrainedModel): config_class = MegatronBertConfig load_tf_weights = load_tf_weights_in_megatron_bert base_model_prefix = "bert" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -719,6 +720,10 @@ class MegatronBertPreTrainedModel(PreTrainedModel): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, MegatronBertEncoder): + module.gradient_checkpointing = value + @dataclass # Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->MegatronBert diff --git a/src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py b/src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py index 72271885bb..57d42a1171 100644 --- a/src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py +++ b/src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py @@ -279,7 +279,6 @@ def main(): summary_proj_to_labels=True, summary_first_dropout=0.1, scale_attn_weights=True, - gradient_checkpointing=False, use_cache=True, bos_token_id=50256, eos_token_id=50256, diff --git a/src/transformers/models/pegasus/configuration_pegasus.py b/src/transformers/models/pegasus/configuration_pegasus.py index 2e815c2e48..8cf76c482b 100644 --- a/src/transformers/models/pegasus/configuration_pegasus.py +++ b/src/transformers/models/pegasus/configuration_pegasus.py @@ -78,8 +78,6 @@ class PegasusConfig(PretrainedConfig): decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`): Scale embeddings by diving by sqrt(d_model). use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): @@ -128,7 +126,6 @@ class PegasusConfig(PretrainedConfig): decoder_start_token_id=0, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, pad_token_id=0, eos_token_id=1, forced_eos_token_id=1, @@ -153,7 +150,6 @@ class PegasusConfig(PretrainedConfig): self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( pad_token_id=pad_token_id, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index ab1009b339..2728f144b3 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -466,6 +466,7 @@ class PegasusDecoderLayer(nn.Module): class PegasusPreTrainedModel(PreTrainedModel): config_class = PegasusConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -480,6 +481,10 @@ class PegasusPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (PegasusDecoder, PegasusEncoder)): + module.gradient_checkpointing = value + PEGASUS_START_DOCSTRING = r""" This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic @@ -646,6 +651,7 @@ class PegasusEncoder(PegasusPreTrainedModel): self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def resize_position_embeddings(self, new_num_position_embeddings: int): """ @@ -770,7 +776,7 @@ class PegasusEncoder(PegasusPreTrainedModel): if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -840,6 +846,7 @@ class PegasusDecoder(PegasusPreTrainedModel): self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -1040,12 +1047,11 @@ class PegasusDecoder(PegasusPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/prophetnet/configuration_prophetnet.py b/src/transformers/models/prophetnet/configuration_prophetnet.py index c19e4a106f..074bad3e24 100644 --- a/src/transformers/models/prophetnet/configuration_prophetnet.py +++ b/src/transformers/models/prophetnet/configuration_prophetnet.py @@ -92,8 +92,6 @@ class ProphetNetConfig(PretrainedConfig): smoothing is performed. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. """ model_type = "prophetnet" keys_to_ignore_at_inference = ["past_key_values"] @@ -124,7 +122,6 @@ class ProphetNetConfig(PretrainedConfig): num_buckets=32, relative_max_distance=128, disable_ngram_loss=False, - gradient_checkpointing=False, eps=0.0, use_cache=True, pad_token_id=0, @@ -158,9 +155,6 @@ class ProphetNetConfig(PretrainedConfig): self.use_cache = use_cache - # 4 Training Args (should be removed soon) - self.gradient_checkpointing = gradient_checkpointing - super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index ed4c792657..9f72a35f0d 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -547,6 +547,7 @@ class ProphetNetDecoderLMOutput(ModelOutput): class ProphetNetPreTrainedModel(PreTrainedModel): config_class = ProphetNetConfig base_model_prefix = "prophetnet" + supports_gradient_checkpointing = True def _init_weights(self, module): if isinstance(module, nn.Linear): @@ -558,6 +559,10 @@ class ProphetNetPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (ProphetNetDecoder, ProphetNetEncoder)): + module.gradient_checkpointing = value + def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id pad_token_id = self.config.pad_token_id @@ -1262,6 +1267,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel): self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)]) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.word_embeddings @@ -1337,7 +1343,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel): if output_hidden_states: encoder_hidden_states = encoder_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -1406,6 +1412,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): self.embeddings_layer_norm = LayerNorm(config.hidden_size) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.word_embeddings @@ -1566,12 +1573,11 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/rembert/configuration_rembert.py b/src/transformers/models/rembert/configuration_rembert.py index d9432d20a9..51c899dfc9 100644 --- a/src/transformers/models/rembert/configuration_rembert.py +++ b/src/transformers/models/rembert/configuration_rembert.py @@ -76,8 +76,6 @@ class RemBertConfig(PretrainedConfig): use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if ``config.is_decoder=True``. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 46524ce9cb..ab3874865a 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -501,6 +501,7 @@ class RemBertEncoder(nn.Module): self.embedding_hidden_mapping_in = nn.Linear(config.input_embedding_size, config.hidden_size) self.layer = nn.ModuleList([RemBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -528,12 +529,11 @@ class RemBertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warn( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -648,6 +648,7 @@ class RemBertPreTrainedModel(PreTrainedModel): config_class = RemBertConfig load_tf_weights = load_tf_weights_in_rembert base_model_prefix = "rembert" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -666,6 +667,10 @@ class RemBertPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, RemBertEncoder): + module.gradient_checkpointing = value + REMBERT_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. Use diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 09472da767..f74954ac64 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -469,6 +469,7 @@ class RobertaEncoder(nn.Module): super().__init__() self.config = config self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -495,12 +496,11 @@ class RobertaEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -585,6 +585,7 @@ class RobertaPreTrainedModel(PreTrainedModel): config_class = RobertaConfig base_model_prefix = "roberta" + supports_gradient_checkpointing = True # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): @@ -603,6 +604,10 @@ class RobertaPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, RobertaEncoder): + module.gradient_checkpointing = value + def update_keys_to_ignore(self, config, del_keys_to_ignore): """Remove some keys from ignore list""" if not config.tie_word_embeddings: diff --git a/src/transformers/models/roformer/configuration_roformer.py b/src/transformers/models/roformer/configuration_roformer.py index 945d1064a1..5027b3be1f 100644 --- a/src/transformers/models/roformer/configuration_roformer.py +++ b/src/transformers/models/roformer/configuration_roformer.py @@ -80,8 +80,6 @@ class RoFormerConfig(PretrainedConfig): relevant if ``config.is_decoder=True``. rotary_value (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not apply rotary position embeddings on value layer. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: @@ -114,7 +112,6 @@ class RoFormerConfig(PretrainedConfig): initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - gradient_checkpointing=False, rotary_value=False, use_cache=True, **kwargs @@ -134,6 +131,5 @@ class RoFormerConfig(PretrainedConfig): self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing self.rotary_value = rotary_value self.use_cache = use_cache diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index f08d3e5c8f..23929a4c61 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -551,6 +551,7 @@ class RoFormerEncoder(nn.Module): config.max_position_embeddings, config.hidden_size // config.num_attention_heads ) self.layer = nn.ModuleList([RoFormerLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -580,12 +581,11 @@ class RoFormerEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -705,6 +705,7 @@ class RoFormerPreTrainedModel(PreTrainedModel): config_class = RoFormerConfig load_tf_weights = load_tf_weights_in_roformer base_model_prefix = "roformer" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [] _keys_to_ignore_on_load_unexpected = [ r"roformer\.embeddings_project\.weight", @@ -729,6 +730,10 @@ class RoFormerPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, RoFormerEncoder): + module.gradient_checkpointing = value + ROFORMER_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. Use diff --git a/src/transformers/models/speech_to_text/configuration_speech_to_text.py b/src/transformers/models/speech_to_text/configuration_speech_to_text.py index ff16601030..821362d2e6 100644 --- a/src/transformers/models/speech_to_text/configuration_speech_to_text.py +++ b/src/transformers/models/speech_to_text/configuration_speech_to_text.py @@ -134,7 +134,6 @@ class Speech2TextConfig(PretrainedConfig): decoder_start_token_id=2, classifier_dropout=0.0, scale_embedding=True, - gradient_checkpointing=False, pad_token_id=1, bos_token_id=0, eos_token_id=2, @@ -165,7 +164,6 @@ class Speech2TextConfig(PretrainedConfig): self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True self.max_source_positions = max_source_positions self.max_target_positions = max_target_positions diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index ce19d680ab..e91af884c6 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -531,6 +531,7 @@ class Speech2TextDecoderLayer(nn.Module): class Speech2TextPreTrainedModel(PreTrainedModel): config_class = Speech2TextConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -543,6 +544,10 @@ class Speech2TextPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Speech2TextDecoder, Speech2TextEncoder)): + module.gradient_checkpointing = value + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ Computes the output length of the convolutional layers @@ -711,6 +716,7 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel): self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -795,7 +801,7 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel): if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -863,6 +869,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -1032,11 +1039,11 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." ) use_cache = False diff --git a/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py index f1f9505990..abeac09a10 100644 --- a/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py @@ -108,7 +108,6 @@ class Speech2Text2Config(PretrainedConfig): decoder_start_token_id=2, classifier_dropout=0.0, scale_embedding=True, - gradient_checkpointing=False, pad_token_id=1, bos_token_id=0, eos_token_id=2, @@ -130,7 +129,6 @@ class Speech2Text2Config(PretrainedConfig): self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = decoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True self.max_source_positions = max_source_positions self.max_target_positions = max_target_positions diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index 848df757c3..fbbbaa3cbf 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -407,6 +407,7 @@ class Speech2Text2DecoderLayer(nn.Module): class Speech2Text2PreTrainedModel(PreTrainedModel): config_class = Speech2Text2Config base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -419,6 +420,10 @@ class Speech2Text2PreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, Speech2Text2Decoder): + module.gradient_checkpointing = value + SPEECH_TO_TEXT_2_START_DOCSTRING = r""" This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic @@ -465,6 +470,7 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel): self.layers = nn.ModuleList([Speech2Text2DecoderLayer(config) for _ in range(config.decoder_layers)]) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -635,11 +641,11 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." ) use_cache = False diff --git a/src/transformers/models/splinter/configuration_splinter.py b/src/transformers/models/splinter/configuration_splinter.py index 879451bbe5..986e436fe7 100644 --- a/src/transformers/models/splinter/configuration_splinter.py +++ b/src/transformers/models/splinter/configuration_splinter.py @@ -71,8 +71,6 @@ class SplinterConfig(PretrainedConfig): use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if ``config.is_decoder=True``. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass. question_token_id (:obj:`int`, `optional`, defaults to 104): The id of the ``[QUESTION]`` token. diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 1296db1250..381a280ebb 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -409,6 +409,7 @@ class SplinterEncoder(nn.Module): super().__init__() self.config = config self.layer = nn.ModuleList([SplinterLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -435,12 +436,11 @@ class SplinterEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -509,6 +509,7 @@ class SplinterPreTrainedModel(PreTrainedModel): config_class = SplinterConfig base_model_prefix = "splinter" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights @@ -528,6 +529,10 @@ class SplinterPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, SplinterEncoder): + module.gradient_checkpointing = value + SPLINTER_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. Use diff --git a/src/transformers/models/t5/configuration_t5.py b/src/transformers/models/t5/configuration_t5.py index 9a40659127..bb16a5fb0f 100644 --- a/src/transformers/models/t5/configuration_t5.py +++ b/src/transformers/models/t5/configuration_t5.py @@ -77,8 +77,6 @@ class T5Config(PretrainedConfig): the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. """ model_type = "t5" keys_to_ignore_at_inference = ["past_key_values"] @@ -102,7 +100,6 @@ class T5Config(PretrainedConfig): use_cache=True, pad_token_id=0, eos_token_id=1, - gradient_checkpointing=False, **kwargs ): self.vocab_size = vocab_size @@ -120,7 +117,6 @@ class T5Config(PretrainedConfig): self.initializer_factor = initializer_factor self.feed_forward_proj = feed_forward_proj self.use_cache = use_cache - self.gradient_checkpointing = gradient_checkpointing super().__init__( pad_token_id=pad_token_id, eos_token_id=eos_token_id, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 27ef440bfb..f18c9e66f5 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -325,7 +325,7 @@ class T5Attention(nn.Module): if self.has_relative_attention_bias: self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) self.pruned_heads = set() - self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False) + self.gradient_checkpointing = False def prune_heads(self, heads): if len(heads) == 0: @@ -489,7 +489,7 @@ class T5Attention(nn.Module): position_bias = torch.zeros( (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype ) - if self.training and self.gradient_checkpointing: + if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: position_bias = self.compute_bias(real_seq_length, key_length) @@ -715,6 +715,7 @@ class T5PreTrainedModel(PreTrainedModel): load_tf_weights = load_tf_weights_in_t5 base_model_prefix = "transformer" is_parallelizable = True + supports_gradient_checkpointing = True @property def dummy_inputs(self): @@ -769,6 +770,10 @@ class T5PreTrainedModel(PreTrainedModel): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (T5Attention, T5Stack)): + module.gradient_checkpointing = value + def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id pad_token_id = self.config.pad_token_id @@ -813,6 +818,7 @@ class T5Stack(T5PreTrainedModel): # Model parallel self.model_parallel = False self.device_map = None + self.gradient_checkpointing = False @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): @@ -968,11 +974,10 @@ class T5Stack(T5PreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warn( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/tapas/configuration_tapas.py b/src/transformers/models/tapas/configuration_tapas.py index 834cae0c7e..d59dc00f45 100644 --- a/src/transformers/models/tapas/configuration_tapas.py +++ b/src/transformers/models/tapas/configuration_tapas.py @@ -73,8 +73,6 @@ class TapasConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether to use gradient checkpointing to save memory at the expense of a slower backward pass. positive_label_weight (:obj:`float`, `optional`, defaults to 10.0): Weight for positive labels. num_aggregation_labels (:obj:`int`, `optional`, defaults to 0): @@ -159,7 +157,6 @@ class TapasConfig(PretrainedConfig): initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - gradient_checkpointing=False, positive_label_weight=10.0, num_aggregation_labels=0, aggregation_loss_weight=1.0, @@ -202,7 +199,6 @@ class TapasConfig(PretrainedConfig): self.type_vocab_sizes = type_vocab_sizes self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing # Fine-tuning task hyperparameters self.positive_label_weight = positive_label_weight diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 29d4a3ef4f..9506216522 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -627,6 +627,7 @@ class TapasEncoder(nn.Module): super().__init__() self.config = config self.layer = nn.ModuleList([TapasLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -649,7 +650,7 @@ class TapasEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -763,6 +764,7 @@ class TapasPreTrainedModel(PreTrainedModel): config_class = TapasConfig base_model_prefix = "tapas" + supports_gradient_checkpointing = True # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): @@ -781,6 +783,10 @@ class TapasPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, TapasEncoder): + module.gradient_checkpointing = value + TAPAS_START_DOCSTRING = r""" This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index 21f5e01362..c6c0101008 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -398,6 +398,7 @@ class VisualBertEncoder(nn.Module): super().__init__() self.config = config self.layer = nn.ModuleList([VisualBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -417,7 +418,7 @@ class VisualBertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -532,6 +533,7 @@ class VisualBertPreTrainedModel(PreTrainedModel): config_class = VisualBertConfig base_model_prefix = "visual_bert" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -547,6 +549,10 @@ class VisualBertPreTrainedModel(PreTrainedModel): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, VisualBertEncoder): + module.gradient_checkpointing = value + @dataclass class VisualBertForPreTrainingOutput(ModelOutput): diff --git a/src/transformers/models/vit/configuration_vit.py b/src/transformers/models/vit/configuration_vit.py index 5e53df4cdd..9c64be5141 100644 --- a/src/transformers/models/vit/configuration_vit.py +++ b/src/transformers/models/vit/configuration_vit.py @@ -57,8 +57,6 @@ class ViTConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. image_size (:obj:`int`, `optional`, defaults to :obj:`224`): The size (resolution) of each image. patch_size (:obj:`int`, `optional`, defaults to :obj:`16`): diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 5b147f2856..78911f7b41 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -352,6 +352,7 @@ class ViTEncoder(nn.Module): super().__init__() self.config = config self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -370,7 +371,7 @@ class ViTEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -411,6 +412,7 @@ class ViTPreTrainedModel(PreTrainedModel): config_class = ViTConfig base_model_prefix = "vit" + supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" @@ -428,6 +430,10 @@ class ViTPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ViTEncoder): + module.gradient_checkpointing = value + VIT_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ subclass. Use diff --git a/src/transformers/models/wav2vec2/configuration_wav2vec2.py b/src/transformers/models/wav2vec2/configuration_wav2vec2.py index d82e6a6d34..49818feb22 100644 --- a/src/transformers/models/wav2vec2/configuration_wav2vec2.py +++ b/src/transformers/models/wav2vec2/configuration_wav2vec2.py @@ -138,8 +138,6 @@ class Wav2Vec2Config(PretrainedConfig): instance of :class:`~transformers.Wav2Vec2ForSequenceClassification`. classifier_proj_size (:obj:`int`, `optional`, defaults to 256): Dimensionality of the projection before token mean-pooling for classification. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: @@ -198,7 +196,6 @@ class Wav2Vec2Config(PretrainedConfig): ctc_zero_infinity=False, use_weighted_layer_sum=False, classifier_proj_size=256, - gradient_checkpointing=False, pad_token_id=0, bos_token_id=1, eos_token_id=2, @@ -229,7 +226,6 @@ class Wav2Vec2Config(PretrainedConfig): self.initializer_range = initializer_range self.vocab_size = vocab_size self.do_stable_layer_norm = do_stable_layer_norm - self.gradient_checkpointing = gradient_checkpointing self.use_weighted_layer_sum = use_weighted_layer_sum self.classifier_proj_size = classifier_proj_size diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index ade54417f1..71f431ca97 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -590,6 +590,7 @@ class Wav2Vec2Encoder(nn.Module): self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -629,7 +630,7 @@ class Wav2Vec2Encoder(nn.Module): skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: # create gradient checkpointing function def create_custom_forward(module): def custom_forward(*inputs): @@ -676,6 +677,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module): self.layers = nn.ModuleList( [Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] ) + self.gradient_checkpointing = False def forward( self, @@ -715,7 +717,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: # create gradient checkpointing function def create_custom_forward(module): def custom_forward(*inputs): @@ -842,6 +844,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): config_class = Wav2Vec2Config base_model_prefix = "wav2vec2" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -864,6 +867,10 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm)): + module.gradient_checkpointing = value + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ Computes the output length of the convolutional layers diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f5aa74616c..d39a24bf46 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -990,7 +990,7 @@ class Trainer: elif isinstance(model, PreTrainedModel): # find_unused_parameters breaks checkpointing as per # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 - find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False) + find_unused_parameters = not getattr(model.config, "_gradient_checkpointing", False) else: find_unused_parameters = True model = nn.parallel.DistributedDataParallel( @@ -1162,6 +1162,10 @@ class Trainer: self.state = TrainerState() self.state.is_hyper_param_search = trial is not None + # Activate gradient checkpointing if needed + if args.gradient_checkpointing: + self.model.gradient_checkpointing_enable() + model = self._wrap_model(self.model_wrapped) # for the rest of this function `model` is the outside model, whether it was wrapped or not diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index d34622abc0..ce330a254c 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -372,6 +372,8 @@ class TrainingArguments: hub_token (:obj:`str`, `optional`): The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with :obj:`huggingface-cli login`. + gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. """ output_dir: str = field( @@ -650,6 +652,12 @@ class TrainingArguments: metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."}, ) hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."}) + gradient_checkpointing: bool = field( + default=False, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) # Deprecated arguments push_to_hub_model_id: str = field( default=None, metadata={"help": "The name of the repository to which push the `Trainer`."} diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/configuration_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/configuration_{{cookiecutter.lowercase_modelname}}.py index 93da35a5d9..6978a3ddf3 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/configuration_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/configuration_{{cookiecutter.lowercase_modelname}}.py @@ -72,8 +72,6 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig): use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if ``config.is_decoder=True``. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass. {% else -%} vocab_size (:obj:`int`, `optional`, defaults to 50265): Vocabulary size of the {{cookiecutter.modelname}} model. Defines the number of different tokens that can be represented by the @@ -186,7 +184,6 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig): decoder_start_token_id=2, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, {% endif -%} pad_token_id=1, bos_token_id=0, @@ -225,7 +222,6 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig): self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True {% endif -%} diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 835382396c..b0482f7062 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -513,6 +513,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module): super().__init__() self.config = config self.layer = nn.ModuleList([{{cookiecutter.camelcase_modelname}}Layer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -539,12 +540,11 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -664,6 +664,7 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel): config_class = {{cookiecutter.camelcase_modelname}}Config load_tf_weights = load_tf_weights_in_{{cookiecutter.lowercase_modelname}} base_model_prefix = "{{cookiecutter.lowercase_modelname}}" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -682,6 +683,10 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, {{cookiecutter.camelcase_modelname}}Encoder): + module.gradient_checkpointing = value + {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. @@ -2006,6 +2011,7 @@ class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module): class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel): config_class = {{cookiecutter.camelcase_modelname}}Config base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -2017,16 +2023,10 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - - @property - def dummy_inputs(self): - pad_token = self.config.pad_token_id - input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) - dummy_inputs = { - "attention_mask": input_ids.ne(pad_token), - "input_ids": input_ids, - } - return dummy_inputs + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ({{cookiecutter.camelcase_modelname}}Decoder, {{cookiecutter.camelcase_modelname}}Encoder)): + module.gradient_checkpointing = value {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" @@ -2213,6 +2213,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model self.layernorm_embedding = nn.LayerNorm(embed_dim) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -2309,7 +2310,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -2376,6 +2377,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model self.layernorm_embedding = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -2545,10 +2547,10 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: - logger.warning("`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`...") + logger.warning("`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`...") use_cache = False def create_custom_forward(module): diff --git a/tests/test_modeling_beit.py b/tests/test_modeling_beit.py index b4c670356f..6557936d59 100644 --- a/tests/test_modeling_beit.py +++ b/tests/test_modeling_beit.py @@ -224,6 +224,27 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase): loss = model(**inputs).loss loss.backward() + def test_training_gradient_checkpointing(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if not self.model_tester.is_training: + return + + config.use_cache = False + config.return_dict = True + + for model_class in self.all_model_classes: + if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing: + continue + # we don't test BeitForMaskedImageModeling + if model_class.__name__ == "BeitForMaskedImageModeling": + continue + model = model_class(config) + model.to(torch_device) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index f1a11871b0..b61cf834fb 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -370,15 +370,14 @@ class ModelTesterMixin: def test_training_gradient_checkpointing(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - if not self.model_tester.is_training or not hasattr(config, "gradient_checkpointing"): + if not self.model_tester.is_training: return - config.gradient_checkpointing = True config.use_cache = False config.return_dict = True for model_class in self.all_model_classes: - if model_class in get_values(MODEL_MAPPING): + if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing: continue model = model_class(config) model.to(torch_device) diff --git a/tests/test_modeling_deit.py b/tests/test_modeling_deit.py index c689a90af7..119e098891 100644 --- a/tests/test_modeling_deit.py +++ b/tests/test_modeling_deit.py @@ -20,6 +20,7 @@ import unittest from transformers import DeiTConfig from transformers.file_utils import cached_property, is_torch_available, is_vision_available +from transformers.models.auto import get_values from transformers.testing_utils import require_torch, require_vision, slow, torch_device from .test_configuration_common import ConfigTester @@ -340,7 +341,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase): for model_class in self.all_model_classes: # DeiTForImageClassificationWithTeacher supports inference-only if ( - model_class in MODEL_MAPPING.values() + model_class in get_values(MODEL_MAPPING) or model_class.__name__ == "DeiTForImageClassificationWithTeacher" ): continue @@ -351,6 +352,27 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase): loss = model(**inputs).loss loss.backward() + def test_training_gradient_checkpointing(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if not self.model_tester.is_training: + return + + config.use_cache = False + config.return_dict = True + + for model_class in self.all_model_classes: + if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing: + continue + # DeiTForImageClassificationWithTeacher supports inference-only + if model_class.__name__ == "DeiTForImageClassificationWithTeacher": + continue + model = model_class(config) + model.to(torch_device) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + def test_for_image_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_image_classification(*config_and_inputs) diff --git a/tests/test_modeling_flax_gpt2.py b/tests/test_modeling_flax_gpt2.py index 0c793ebd27..3b2e43680e 100644 --- a/tests/test_modeling_flax_gpt2.py +++ b/tests/test_modeling_flax_gpt2.py @@ -82,7 +82,7 @@ class FlaxGPT2ModelTester: self.eos_token_id = vocab_size - 1 self.pad_token_id = vocab_size - 1 - def prepare_config_and_inputs(self, gradient_checkpointing=False): + def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_mask = None @@ -100,7 +100,6 @@ class FlaxGPT2ModelTester: bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id, - gradient_checkpointing=gradient_checkpointing, ) return (config, input_ids, input_mask) diff --git a/tests/test_modeling_flax_gpt_neo.py b/tests/test_modeling_flax_gpt_neo.py index 2916bec5b9..7d0d832295 100644 --- a/tests/test_modeling_flax_gpt_neo.py +++ b/tests/test_modeling_flax_gpt_neo.py @@ -86,7 +86,7 @@ class FlaxGPTNeoModelTester: self.eos_token_id = vocab_size - 1 self.pad_token_id = vocab_size - 1 - def prepare_config_and_inputs(self, gradient_checkpointing=False): + def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_mask = None @@ -105,7 +105,6 @@ class FlaxGPTNeoModelTester: pad_token_id=self.pad_token_id, window_size=self.window_size, attention_types=self.attention_types, - gradient_checkpointing=gradient_checkpointing, ) return (config, input_ids, input_mask) diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index 91d2edcdc8..214a17f050 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -96,7 +96,7 @@ class GPT2ModelTester: def get_large_model_config(self): return GPT2Config.from_pretrained("gpt2") - def prepare_config_and_inputs(self, gradient_checkpointing=False): + def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_mask = None @@ -119,7 +119,7 @@ class GPT2ModelTester: token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) choice_labels = ids_tensor([self.batch_size], self.num_choices) - config = self.get_config(gradient_checkpointing=gradient_checkpointing) + config = self.get_config() head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) @@ -135,7 +135,7 @@ class GPT2ModelTester: choice_labels, ) - def get_config(self, gradient_checkpointing=False): + def get_config(self): return GPT2Config( vocab_size=self.vocab_size, n_embd=self.hidden_size, @@ -149,11 +149,10 @@ class GPT2ModelTester: n_ctx=self.max_position_embeddings, type_vocab_size=self.type_vocab_size, initializer_range=self.initializer_range, - use_cache=not gradient_checkpointing, + use_cache=True, bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id, - gradient_checkpointing=gradient_checkpointing, ) def prepare_config_and_inputs_for_decoder(self): @@ -322,9 +321,13 @@ class GPT2ModelTester: self.parent.assertEqual(result.loss.shape, ()) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + def create_and_check_forward_and_backwards( + self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False + ): model = GPT2LMHeadModel(config) model.to(torch_device) + if gradient_checkpointing: + model.gradient_checkpointing_enable() result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) self.parent.assertEqual(result.loss.shape, ()) @@ -478,8 +481,8 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): self.model_tester.create_and_check_gpt2_for_token_classification(*config_and_inputs) def test_gpt2_gradient_checkpointing(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True) - self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs) + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True) @slow def test_batch_generation(self): @@ -612,7 +615,11 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase): @slow def test_lm_generate_gpt2(self): for checkpointing in [True, False]: - model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=checkpointing) + model = GPT2LMHeadModel.from_pretrained("gpt2") + if checkpointing: + model.gradient_checkpointing_enable() + else: + model.gradient_checkpointing_disable() model.to(torch_device) input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog expected_output_ids = [ diff --git a/tests/test_modeling_gpt_neo.py b/tests/test_modeling_gpt_neo.py index fa1b63b4f6..a8e5b4babc 100644 --- a/tests/test_modeling_gpt_neo.py +++ b/tests/test_modeling_gpt_neo.py @@ -97,7 +97,7 @@ class GPTNeoModelTester: def get_large_model_config(self): return GPTNeoConfig.from_pretrained("gpt_neo") - def prepare_config_and_inputs(self, gradient_checkpointing=False): + def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_mask = None @@ -120,7 +120,7 @@ class GPTNeoModelTester: token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) choice_labels = ids_tensor([self.batch_size], self.num_choices) - config = self.get_config(gradient_checkpointing=False) + config = self.get_config() head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) @@ -136,18 +136,17 @@ class GPTNeoModelTester: choice_labels, ) - def get_config(self, gradient_checkpointing=False): + def get_config(self): return GPTNeoConfig( vocab_size=self.vocab_size, hidden_size=self.hidden_size, num_layers=self.num_hidden_layers, num_heads=self.num_attention_heads, max_position_embeddings=self.max_position_embeddings, - use_cache=not gradient_checkpointing, + use_cache=True, bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id, - gradient_checkpointing=gradient_checkpointing, window_size=self.window_size, attention_types=self.attention_types, ) @@ -329,8 +328,12 @@ class GPTNeoModelTester: result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) - def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + def create_and_check_forward_and_backwards( + self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False + ): model = GPTNeoForCausalLM(config) + if gradient_checkpointing: + model.gradient_checkpointing_enable() model.to(torch_device) result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) @@ -411,8 +414,8 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase self.model_tester.create_and_check_gpt_neo_for_sequence_classification(*config_and_inputs) def test_gpt_neo_gradient_checkpointing(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True) - self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs) + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True) def _get_hidden_states(self): return torch.tensor( @@ -473,7 +476,10 @@ class GPTNeoModelLanguageGenerationTest(unittest.TestCase): def test_lm_generate_gpt_neo(self): for checkpointing in [True, False]: model = self.model - model.config.gradient_checkpointing = checkpointing + if checkpointing: + model.gradient_checkpointing_enable() + else: + model.gradient_checkpointing_disable() input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog # fmt: off # The dog-eared copy of the book, which is a collection of essays by the late author, diff --git a/tests/test_modeling_gptj.py b/tests/test_modeling_gptj.py index 5739aed5a1..06979a2c7f 100644 --- a/tests/test_modeling_gptj.py +++ b/tests/test_modeling_gptj.py @@ -92,7 +92,7 @@ class GPTJModelTester: def get_large_model_config(self): return GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B") - def prepare_config_and_inputs(self, gradient_checkpointing=False): + def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_mask = None @@ -115,7 +115,7 @@ class GPTJModelTester: token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) choice_labels = ids_tensor([self.batch_size], self.num_choices) - config = self.get_config(gradient_checkpointing=gradient_checkpointing) + config = self.get_config() head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) @@ -131,7 +131,7 @@ class GPTJModelTester: choice_labels, ) - def get_config(self, gradient_checkpointing=False): + def get_config(self): return GPTJConfig( vocab_size=self.vocab_size, n_embd=self.hidden_size, @@ -145,11 +145,10 @@ class GPTJModelTester: n_ctx=self.max_position_embeddings, type_vocab_size=self.type_vocab_size, initializer_range=self.initializer_range, - use_cache=not gradient_checkpointing, + use_cache=True, bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id, - gradient_checkpointing=gradient_checkpointing, ) def prepare_config_and_inputs_for_decoder(self): @@ -318,8 +317,12 @@ class GPTJModelTester: self.parent.assertEqual(result.loss.shape, ()) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + def create_and_check_forward_and_backwards( + self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False + ): model = GPTJForCausalLM(config) + if gradient_checkpointing: + model.gradient_checkpointing_enable() model.to(torch_device) result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) @@ -390,8 +393,8 @@ class GPTJModelTest(unittest.TestCase): self.model_tester.create_and_check_lm_head_model(*config_and_inputs) def test_gptj_gradient_checkpointing(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True) - self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs) + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True) @slow def test_batch_generation(self): @@ -464,7 +467,11 @@ class GPTJModelLanguageGenerationTest(unittest.TestCase): @slow def test_lm_generate_gptj(self): for checkpointing in [True, False]: - model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", gradient_checkpointing=checkpointing) + model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B") + if checkpointing: + model.gradient_checkpointing_enable() + else: + model.gradient_checkpointing_disable() model.to(torch_device) input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog expected_output_ids = [