From 06e782da4e58f93a60c6bedc84b5991abaae58f5 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 25 Oct 2023 12:16:15 +0200 Subject: [PATCH] [`core`] Refactor of `gradient_checkpointing` (#27020) * v1 * fix * remove `create_custom_forward` * fixup * fixup * add test and fix all failing GC tests * remove all remaining `create_custom_forward` methods * fix idefics bug * fixup * replace with `__call__` * add comment * quality --- src/transformers/modeling_utils.py | 22 +++++++-- .../models/align/modeling_align.py | 20 ++++---- .../models/altclip/modeling_altclip.py | 33 +++++-------- .../modeling_audio_spectrogram_transformer.py | 17 +++---- .../models/autoformer/modeling_autoformer.py | 31 ++++-------- src/transformers/models/bark/modeling_bark.py | 19 +++----- src/transformers/models/bart/modeling_bart.py | 31 ++++-------- src/transformers/models/beit/modeling_beit.py | 17 +++---- src/transformers/models/bert/modeling_bert.py | 18 +++---- .../modeling_bert_generation.py | 18 +++---- .../models/big_bird/modeling_big_bird.py | 18 +++---- .../modeling_bigbird_pegasus.py | 31 ++++-------- .../models/biogpt/modeling_biogpt.py | 19 +++----- src/transformers/models/bit/modeling_bit.py | 5 +- .../models/blenderbot/modeling_blenderbot.py | 31 ++++-------- .../modeling_blenderbot_small.py | 31 ++++-------- src/transformers/models/blip/modeling_blip.py | 21 ++++---- .../models/blip/modeling_blip_text.py | 13 ++--- .../models/blip_2/modeling_blip_2.py | 34 +++++-------- .../models/bloom/modeling_bloom.py | 19 +++----- .../bridgetower/modeling_bridgetower.py | 13 ++--- src/transformers/models/bros/modeling_bros.py | 12 ++--- .../models/camembert/modeling_camembert.py | 18 +++---- .../models/canine/modeling_canine.py | 17 +++---- .../chinese_clip/modeling_chinese_clip.py | 30 ++++-------- src/transformers/models/clap/modeling_clap.py | 29 ++++------- src/transformers/models/clip/modeling_clip.py | 17 +++---- .../models/clipseg/modeling_clipseg.py | 17 +++---- .../models/codegen/modeling_codegen.py | 19 +++----- .../modeling_conditional_detr.py | 16 ++----- .../models/convbert/modeling_convbert.py | 17 +++---- .../models/convnext/modeling_convnext.py | 5 +- .../models/convnextv2/modeling_convnextv2.py | 5 +- .../models/cpmant/modeling_cpmant.py | 5 +- .../data2vec/modeling_data2vec_audio.py | 28 ++++------- .../models/data2vec/modeling_data2vec_text.py | 18 +++---- .../data2vec/modeling_data2vec_vision.py | 17 +++---- .../models/deberta/modeling_deberta.py | 17 +++---- .../models/deberta_v2/modeling_deberta_v2.py | 17 +++---- .../modeling_decision_transformer.py | 19 +++----- .../modeling_deformable_detr.py | 16 ++----- src/transformers/models/deit/modeling_deit.py | 17 +++---- .../models/deprecated/mctct/modeling_mctct.py | 17 +++---- .../open_llama/modeling_open_llama.py | 19 +++----- .../modeling_trajectory_transformer.py | 16 ++----- .../models/deprecated/van/modeling_van.py | 5 +- src/transformers/models/deta/modeling_deta.py | 16 ++----- src/transformers/models/detr/modeling_detr.py | 16 ++----- .../models/dinat/modeling_dinat.py | 2 +- .../models/dinov2/modeling_dinov2.py | 17 +++---- .../models/distilbert/modeling_distilbert.py | 17 +++---- .../models/donut/modeling_donut_swin.py | 16 ++----- src/transformers/models/dpr/modeling_dpr.py | 5 +- src/transformers/models/dpt/modeling_dpt.py | 17 +++---- .../efficientnet/modeling_efficientnet.py | 5 +- .../models/electra/modeling_electra.py | 18 +++---- .../models/encodec/modeling_encodec.py | 5 +- .../modeling_encoder_decoder.py | 6 +-- .../models/ernie/modeling_ernie.py | 18 +++---- .../models/ernie_m/modeling_ernie_m.py | 5 +- src/transformers/models/esm/modeling_esm.py | 18 +++---- .../models/falcon/modeling_falcon.py | 20 ++++---- .../models/flava/modeling_flava.py | 17 +++---- src/transformers/models/fnet/modeling_fnet.py | 14 ++---- .../models/focalnet/modeling_focalnet.py | 16 ++----- src/transformers/models/fuyu/modeling_fuyu.py | 5 +- src/transformers/models/git/modeling_git.py | 30 ++++-------- src/transformers/models/gpt2/modeling_gpt2.py | 20 +++----- .../gpt_bigcode/modeling_gpt_bigcode.py | 19 +++----- .../models/gpt_neo/modeling_gpt_neo.py | 19 +++----- .../models/gpt_neox/modeling_gpt_neox.py | 20 ++++---- .../modeling_gpt_neox_japanese.py | 5 +- src/transformers/models/gptj/modeling_gptj.py | 19 +++----- .../modeling_gptsan_japanese.py | 5 +- .../models/graphormer/modeling_graphormer.py | 5 +- .../models/groupvit/modeling_groupvit.py | 18 +++---- .../models/hubert/modeling_hubert.py | 42 +++++----------- .../models/idefics/modeling_idefics.py | 12 ++--- src/transformers/models/idefics/vision.py | 12 ++--- .../models/imagegpt/modeling_imagegpt.py | 19 +++----- .../models/informer/modeling_informer.py | 33 +++++-------- .../instructblip/modeling_instructblip.py | 34 +++++-------- .../models/layoutlm/modeling_layoutlm.py | 18 +++---- .../models/layoutlmv2/modeling_layoutlmv2.py | 17 +++---- .../models/layoutlmv3/modeling_layoutlmv3.py | 15 +----- src/transformers/models/led/modeling_led.py | 32 +++++-------- .../models/levit/modeling_levit.py | 5 +- src/transformers/models/lilt/modeling_lilt.py | 17 +++---- .../models/llama/modeling_llama.py | 23 +++++---- .../models/longformer/modeling_longformer.py | 18 +++---- .../models/longt5/modeling_longt5.py | 20 +++----- src/transformers/models/luke/modeling_luke.py | 17 +++---- .../models/m2m_100/modeling_m2m_100.py | 31 ++++-------- .../models/marian/modeling_marian.py | 31 ++++-------- .../models/markuplm/modeling_markuplm.py | 13 ++--- .../mask2former/modeling_mask2former.py | 12 ++--- .../models/maskformer/modeling_maskformer.py | 20 ++++---- .../maskformer/modeling_maskformer_swin.py | 19 ++++---- .../models/mbart/modeling_mbart.py | 33 +++++-------- .../megatron_bert/modeling_megatron_bert.py | 18 +++---- .../models/mgp_str/modeling_mgp_str.py | 5 +- .../models/mistral/modeling_mistral.py | 20 ++++---- .../models/mobilevit/modeling_mobilevit.py | 16 ++----- .../mobilevitv2/modeling_mobilevitv2.py | 16 ++----- src/transformers/models/mpt/modeling_mpt.py | 19 +++----- src/transformers/models/mra/modeling_mra.py | 16 ++----- src/transformers/models/mt5/modeling_mt5.py | 16 +++---- .../models/musicgen/modeling_musicgen.py | 23 ++++----- src/transformers/models/mvp/modeling_mvp.py | 31 ++++-------- src/transformers/models/nat/modeling_nat.py | 2 +- .../models/nezha/modeling_nezha.py | 18 +++---- .../models/nllb_moe/modeling_nllb_moe.py | 28 ++++------- .../nystromformer/modeling_nystromformer.py | 17 +++---- .../models/oneformer/modeling_oneformer.py | 2 +- src/transformers/models/opt/modeling_opt.py | 19 +++----- .../models/owlv2/modeling_owlv2.py | 18 +++---- .../models/owlvit/modeling_owlvit.py | 17 +++---- .../models/pegasus/modeling_pegasus.py | 31 ++++-------- .../models/pegasus_x/modeling_pegasus_x.py | 31 ++++-------- .../models/persimmon/modeling_persimmon.py | 19 +++----- .../models/pix2struct/modeling_pix2struct.py | 35 +++++--------- .../models/plbart/modeling_plbart.py | 31 ++++-------- .../models/poolformer/modeling_poolformer.py | 5 +- .../models/pop2piano/modeling_pop2piano.py | 16 +++---- .../models/prophetnet/modeling_prophetnet.py | 31 ++++-------- src/transformers/models/pvt/modeling_pvt.py | 5 +- .../models/qdqbert/modeling_qdqbert.py | 18 +++---- .../models/realm/modeling_realm.py | 13 ++--- .../models/regnet/modeling_regnet.py | 5 +- .../models/rembert/modeling_rembert.py | 18 +++---- .../models/resnet/modeling_resnet.py | 5 +- .../models/roberta/modeling_roberta.py | 18 +++---- .../modeling_roberta_prelayernorm.py | 18 +++---- .../models/roc_bert/modeling_roc_bert.py | 18 +++---- .../models/roformer/modeling_roformer.py | 18 +++---- src/transformers/models/rwkv/modeling_rwkv.py | 17 ++----- src/transformers/models/sam/modeling_sam.py | 11 +---- .../seamless_m4t/modeling_seamless_m4t.py | 42 +++++----------- src/transformers/models/sew/modeling_sew.py | 28 ++++------- .../models/sew_d/modeling_sew_d.py | 30 ++++-------- .../modeling_speech_encoder_decoder.py | 6 +-- .../speech_to_text/modeling_speech_to_text.py | 31 ++++-------- .../modeling_speech_to_text_2.py | 17 ++----- .../models/speecht5/modeling_speecht5.py | 48 +++++-------------- .../models/splinter/modeling_splinter.py | 18 +++---- .../swiftformer/modeling_swiftformer.py | 5 +- src/transformers/models/swin/modeling_swin.py | 16 ++----- .../models/swin/modeling_tf_swin.py | 5 -- .../models/swin2sr/modeling_swin2sr.py | 16 ++----- .../models/swinv2/modeling_swinv2.py | 16 ++----- .../modeling_switch_transformers.py | 16 +++---- src/transformers/models/t5/modeling_t5.py | 16 +++---- .../modeling_table_transformer.py | 16 ++----- .../models/tapas/modeling_tapas.py | 18 +++---- .../modeling_time_series_transformer.py | 31 ++++-------- .../timesformer/modeling_timesformer.py | 17 +++---- .../models/trocr/modeling_trocr.py | 19 +++----- src/transformers/models/tvlt/modeling_tvlt.py | 31 ++++-------- src/transformers/models/umt5/modeling_umt5.py | 16 +++---- .../models/unispeech/modeling_unispeech.py | 40 +++++----------- .../unispeech_sat/modeling_unispeech_sat.py | 40 +++++----------- .../models/upernet/modeling_upernet.py | 5 +- .../models/videomae/modeling_videomae.py | 31 ++++-------- src/transformers/models/vilt/modeling_vilt.py | 17 +++---- .../modeling_vision_encoder_decoder.py | 6 +-- .../visual_bert/modeling_visual_bert.py | 17 +++---- src/transformers/models/vit/modeling_vit.py | 17 +++---- .../models/vit_hybrid/modeling_vit_hybrid.py | 17 +++---- .../models/vit_mae/modeling_vit_mae.py | 31 ++++-------- .../models/vit_msn/modeling_vit_msn.py | 17 +++---- .../models/vitdet/modeling_vitdet.py | 17 +++---- .../models/vitmatte/modeling_vitmatte.py | 10 +++- src/transformers/models/vits/modeling_vits.py | 19 +++----- .../models/vivit/modeling_vivit.py | 17 +++---- .../models/wav2vec2/modeling_wav2vec2.py | 40 +++++----------- .../modeling_wav2vec2_conformer.py | 28 ++++------- .../models/wavlm/modeling_wavlm.py | 40 +++++----------- .../models/whisper/modeling_whisper.py | 31 ++++-------- .../models/x_clip/modeling_x_clip.py | 29 ++++------- src/transformers/models/xglm/modeling_xglm.py | 19 +++----- .../xlm_prophetnet/modeling_xlm_prophetnet.py | 31 ++++-------- .../xlm_roberta/modeling_xlm_roberta.py | 18 +++---- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 13 ++--- src/transformers/models/xmod/modeling_xmod.py | 18 +++---- .../models/yolos/modeling_yolos.py | 17 +++---- src/transformers/models/yoso/modeling_yoso.py | 17 +++---- ...ng_{{cookiecutter.lowercase_modelname}}.py | 47 +++++++----------- tests/test_modeling_common.py | 21 ++++++++ 188 files changed, 1276 insertions(+), 2296 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 6f866f989a..47e9cb2f23 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import collections +import functools import gc import importlib.metadata import inspect @@ -1848,16 +1849,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix self.base_model._prune_heads(heads_to_prune) - def gradient_checkpointing_enable(self): + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): """ Activates gradient checkpointing for the current model. Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint activations". + + We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of + the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + + Args: + gradient_checkpointing_kwargs (dict, *optional*): + Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function. """ if not self.supports_gradient_checkpointing: raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") - self.apply(partial(self._set_gradient_checkpointing, value=True)) + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {} + + gradient_checkpointing_func = functools.partial( + torch.utils.checkpoint.checkpoint, **gradient_checkpointing_kwargs + ) + + self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=gradient_checkpointing_func)) if getattr(self, "_hf_peft_config_loaded", False): # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True @@ -1874,7 +1890,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix activations". """ if self.supports_gradient_checkpointing: - self.apply(partial(self._set_gradient_checkpointing, value=False)) + self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=None)) if getattr(self, "_hf_peft_config_loaded", False): self.disable_input_require_grads() diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index 6cbf01a343..58dc2a8920 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -1095,20 +1095,15 @@ class AlignTextEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -1197,9 +1192,10 @@ class AlignPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (AlignTextModel, AlignVisionModel)): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (AlignTextModel, AlignVisionModel, AlignTextEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index c4e32de55d..e6229165aa 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -646,20 +646,15 @@ class AltRobertaEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -960,18 +955,12 @@ class AltCLIPEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1089,11 +1078,13 @@ class AltCLIPPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, AltCLIPEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None if isinstance(module, AltRobertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None # Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer with CLIPVisionTransformer->AltCLIPVisionTransformer,CLIPVisionConfig->AltCLIPVisionConfig,CLIPVisionEmbeddings->AltCLIPVisionEmbeddings,CLIPEncoder->AltCLIPEncoder,CLIP_VISION_INPUTS_DOCSTRING->ALTCLIP_VISION_INPUTS_DOCSTRING diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index 28969f50b6..a1f85e2a09 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -336,17 +336,11 @@ class ASTEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -395,9 +389,10 @@ class ASTPreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._set_gradient_checkpointing with ViT->AST - def _set_gradient_checkpointing(self, module: ASTEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ASTEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ASTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 96298c77a3..40e3002310 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -946,9 +946,10 @@ class AutoformerPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (AutoformerDecoder, AutoformerEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None AUTOFORMER_START_DOCSTRING = r""" @@ -1207,18 +1208,12 @@ class AutoformerEncoder(AutoformerPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1425,16 +1420,8 @@ class AutoformerDecoder(AutoformerPreTrainedModel): "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1442,6 +1429,8 @@ class AutoformerDecoder(AutoformerPreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 649719e0ee..2708b00d05 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -313,9 +313,10 @@ class BarkPreTrainedModel(PreTrainedModel): return get_parameter_device(self) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BarkCausalModel) or isinstance(module, BarkFineModel) or isinstance(module, BarkModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BARK_MODEL_START_DOCSTRING = """ @@ -637,20 +638,14 @@ class BarkCausalModel(BarkPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 9e7763ca23..73eca72e5d 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -521,9 +521,10 @@ class BartPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BartDecoder, BartEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -854,18 +855,12 @@ class BartEncoder(BartPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1110,16 +1105,8 @@ class BartDecoder(BartPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1127,6 +1114,8 @@ class BartDecoder(BartPreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index d698cff88b..3ba3d4911b 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -510,17 +510,11 @@ class BeitEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: relative_position_bias = ( @@ -572,9 +566,10 @@ class BeitPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BeitEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BEIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 1b0fad3f9d..91380e13a0 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -593,20 +593,15 @@ class BertEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -762,9 +757,10 @@ class BertPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index abe2d828b2..123cb2212e 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -401,20 +401,15 @@ class BertEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -607,9 +602,10 @@ class BertGenerationPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BERT_GENERATION_START_DOCSTRING = r""" diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index e266b1a67b..0ba2119e68 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1617,15 +1617,8 @@ class BigBirdEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, @@ -1635,6 +1628,8 @@ class BigBirdEncoder(nn.Module): from_mask, to_mask, blocked_encoder_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -1784,9 +1779,10 @@ class BigBirdPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BigBirdEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BIG_BIRD_START_DOCSTRING = r""" diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 4e279f9dc0..98ff51032b 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1609,9 +1609,10 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BigBirdPegasusDecoder, BigBirdPegasusEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -1943,15 +1944,8 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1960,6 +1954,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): to_mask, blocked_encoder_mask, blocked_encoder_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -2289,16 +2284,8 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -2306,6 +2293,8 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index ca084db5c7..2bbdbed348 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -376,9 +376,10 @@ class BioGptPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BioGptModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BIOGPT_START_DOCSTRING = r""" @@ -590,20 +591,14 @@ class BioGptModel(BioGptPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/bit/modeling_bit.py b/src/transformers/models/bit/modeling_bit.py index 12a5ecd42b..d02861d634 100644 --- a/src/transformers/models/bit/modeling_bit.py +++ b/src/transformers/models/bit/modeling_bit.py @@ -669,9 +669,10 @@ class BitPreTrainedModel(PreTrainedModel): nn.init.constant_(module.weight, 1) nn.init.constant_(module.bias, 0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BitModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 1db8190521..51a947af0a 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -483,9 +483,10 @@ class BlenderbotPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BlenderbotDecoder, BlenderbotEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -777,18 +778,12 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1032,16 +1027,8 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1049,6 +1036,8 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 129de3dd14..88a9b52de9 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -480,9 +480,10 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BlenderbotSmallDecoder, BlenderbotSmallEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -775,18 +776,12 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1029,16 +1024,8 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1046,6 +1033,8 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 9fca7c28a1..efd986299c 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -34,7 +34,7 @@ from ...utils import ( replace_return_docstrings, ) from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig -from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel +from .modeling_blip_text import BlipTextEncoder, BlipTextLMHeadModel, BlipTextModel logger = logging.get_logger(__name__) @@ -461,9 +461,10 @@ class BlipPreTrainedModel(PreTrainedModel): elif isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, BlipEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (BlipEncoder, BlipTextEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BLIP_START_DOCSTRING = r""" @@ -622,17 +623,11 @@ class BlipEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 49b958afc2..e0aa4e17f1 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -422,20 +422,15 @@ class BlipTextEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index bd56b17e55..2f7f00b3dd 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -297,9 +297,14 @@ class Blip2PreTrainedModel(PreTrainedModel): elif isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, Blip2Encoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (Blip2Encoder, Blip2QFormerEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None + + # Enable / disable GC for the language model as well + if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"): + self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func) BLIP_2_START_DOCSTRING = r""" @@ -473,17 +478,11 @@ class Blip2Encoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -944,15 +943,8 @@ class Blip2QFormerEncoder(nn.Module): "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions, query_length) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index d90bb6ad8f..583367c9ab 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -496,9 +496,10 @@ class BloomPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False): + def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None): if isinstance(module, BloomModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @staticmethod def _convert_to_standard_cache( @@ -761,21 +762,15 @@ class BloomModel(BloomPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, alibi, causal_mask, layer_past, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index ce569157b8..0f272a21e2 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -804,20 +804,15 @@ class BridgeTowerTextEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index a8ea8d4919..c10f835056 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -651,21 +651,15 @@ class BrosEncoder(nn.Module): "`use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, bbox_pos_emb, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 8d7d279579..2e0a6c12fe 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -524,20 +524,15 @@ class CamembertEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -625,9 +620,10 @@ class CamembertPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CamembertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CAMEMBERT_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 657104ad69..adc8759103 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -795,18 +795,12 @@ class CanineEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -919,9 +913,10 @@ class CaninePreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CanineEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CANINE_START_DOCSTRING = r""" diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 7bab0aea6e..ef1c265723 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -742,9 +742,10 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ChineseCLIPVisionEncoder) or isinstance(module, ChineseCLIPTextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CHINESE_CLIP_START_DOCSTRING = r""" @@ -909,20 +910,15 @@ class ChineseCLIPTextEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -1018,16 +1014,10 @@ class ChineseCLIPVisionEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 1d17a51883..025b59ae4b 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -939,15 +939,8 @@ class ClapAudioEncoder(nn.Module): input_dimensions = self.input_resolutions[i] if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( @@ -1595,20 +1588,15 @@ class ClapTextEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -1701,9 +1689,10 @@ class ClapPreTrainedModel(PreTrainedModel): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ClapTextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class ClapAudioModel(ClapPreTrainedModel): diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 3a894b9727..56f24c157f 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -467,9 +467,10 @@ class CLIPPreTrainedModel(PreTrainedModel): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CLIPEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CLIP_START_DOCSTRING = r""" @@ -639,18 +640,12 @@ class CLIPEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index 96f13217aa..7a0e529269 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -479,9 +479,10 @@ class CLIPSegPreTrainedModel(PreTrainedModel): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CLIPSegEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CLIPSEG_START_DOCSTRING = r""" @@ -648,18 +649,12 @@ class CLIPSegEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 05699ef15c..9a5509a9ed 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -339,9 +339,10 @@ class CodeGenPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CodeGenModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CODEGEN_START_DOCSTRING = r""" @@ -542,21 +543,15 @@ class CodeGenModel(CodeGenPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, position_ids, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 69937afefc..01dbf8ecd5 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -1171,9 +1171,10 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConditionalDetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CONDITIONAL_DETR_START_DOCSTRING = r""" @@ -1518,15 +1519,8 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel): # apply transformation query_sine_embed = query_sine_embed_before_transformation * pos_transformation if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, combined_attention_mask, object_queries, diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index a6fccf5b72..da577a5896 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -264,9 +264,10 @@ class ConvBertPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConvBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class SeparableConv1D(nn.Module): @@ -632,20 +633,14 @@ class ConvBertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py index e6cf336517..e11112b532 100755 --- a/src/transformers/models/convnext/modeling_convnext.py +++ b/src/transformers/models/convnext/modeling_convnext.py @@ -296,9 +296,10 @@ class ConvNextPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConvNextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CONVNEXT_START_DOCSTRING = r""" diff --git a/src/transformers/models/convnextv2/modeling_convnextv2.py b/src/transformers/models/convnextv2/modeling_convnextv2.py index 3a268c713d..f1ff89bb12 100644 --- a/src/transformers/models/convnextv2/modeling_convnextv2.py +++ b/src/transformers/models/convnextv2/modeling_convnextv2.py @@ -317,9 +317,10 @@ class ConvNextV2PreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConvNextV2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CONVNEXTV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/cpmant/modeling_cpmant.py b/src/transformers/models/cpmant/modeling_cpmant.py index 6d2dc596fa..8a6c744ed6 100755 --- a/src/transformers/models/cpmant/modeling_cpmant.py +++ b/src/transformers/models/cpmant/modeling_cpmant.py @@ -556,9 +556,10 @@ class CpmAntPreTrainedModel(PreTrainedModel): elif isinstance(module, CpmAntSegmentPositionEmbedding): module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CpmAntEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CPMANT_START_DOCSTRING = r""" diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 4435e9b8d0..a99b6f3a6d 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -293,15 +293,8 @@ class Data2VecAudioFeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -593,17 +586,11 @@ class Data2VecAudioEncoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -761,9 +748,10 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel): attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Data2VecAudioEncoder, Data2VecAudioFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DATA2VEC_AUDIO_START_DOCSTRING = r""" diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index a521ccb39a..507c2fc464 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -510,20 +510,15 @@ class Data2VecTextEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -613,9 +608,10 @@ class Data2VecTextPreTrainedModel(PreTrainedModel): if hasattr(module, "weight") and module.weight is not None: module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Data2VecTextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DATA2VECTEXT_START_DOCSTRING = r""" diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index f8fe59587a..2742d5ffc3 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -522,17 +522,11 @@ class Data2VecVisionEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: relative_position_bias = ( @@ -585,9 +579,10 @@ class Data2VecVisionPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Data2VecVisionEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DATA2VEC_VISION_START_DOCSTRING = r""" diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 6f6c2af63a..65ec497cec 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -457,20 +457,14 @@ class DebertaEncoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + hidden_states = self.gradient_checkpointing_func( + layer_module.__call__, next_kv, attention_mask, query_states, relative_pos, rel_embeddings, + output_attentions, ) else: hidden_states = layer_module( @@ -839,9 +833,10 @@ class DebertaPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DebertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index eda4f406cb..2245ac549a 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -501,20 +501,14 @@ class DebertaV2Encoder(nn.Module): all_hidden_states = all_hidden_states + (output_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - output_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + output_states = self.gradient_checkpointing_func( + layer_module.__call__, next_kv, attention_mask, query_states, relative_pos, rel_embeddings, + output_attentions, ) else: output_states = layer_module( @@ -938,9 +932,10 @@ class DebertaV2PreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DebertaV2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 8e5053a416..19c2731a50 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -469,9 +469,10 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel): # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DecisionTransformerGPT2Model): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): @@ -631,22 +632,16 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index ea4555d5ae..220fcf0d06 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -1088,9 +1088,10 @@ class DeformableDetrPreTrainedModel(PreTrainedModel): if hasattr(module, "level_embed"): nn.init.normal_(module.level_embed) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DeformableDetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEFORMABLE_DETR_START_DOCSTRING = r""" @@ -1383,15 +1384,8 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel): all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, encoder_hidden_states, encoder_attention_mask, diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 38c28dbbed..6e97e932b5 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -357,17 +357,11 @@ class DeiTEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -415,9 +409,10 @@ class DeiTPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: DeiTEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: DeiTEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, DeiTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index eca5ba014e..9e7a73c588 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -504,9 +504,10 @@ class MCTCTPreTrainedModel(PreTrainedModel): attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).long() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MCTCTEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MCTCT_START_DOCSTRING = r""" @@ -616,18 +617,12 @@ class MCTCTEncoder(MCTCTPreTrainedModel): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 3ace323e82..fb1cc7f0fb 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -456,9 +456,10 @@ class OpenLlamaPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, OpenLlamaModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OPEN_LLAMA_INPUTS_DOCSTRING = r""" @@ -665,20 +666,14 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, position_ids, None, + output_attentions, + None, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py index 75415dbe77..c9f31c7144 100644 --- a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py +++ b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py @@ -163,9 +163,10 @@ class TrajectoryTransformerPreTrainedModel(PreTrainedModel): main_input_name = "trajectories" supports_gradient_checkpointing = True - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TrajectoryTransformerModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Embedding)): @@ -550,15 +551,8 @@ class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, layer_past, use_cache, diff --git a/src/transformers/models/deprecated/van/modeling_van.py b/src/transformers/models/deprecated/van/modeling_van.py index 4ef18f5415..52c9e12424 100644 --- a/src/transformers/models/deprecated/van/modeling_van.py +++ b/src/transformers/models/deprecated/van/modeling_van.py @@ -387,9 +387,10 @@ class VanPreTrainedModel(PreTrainedModel): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, VanModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VAN_START_DOCSTRING = r""" diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index 1aab38c289..a6f979eaee 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -979,9 +979,10 @@ class DetaPreTrainedModel(PreTrainedModel): if hasattr(module, "level_embed"): nn.init.normal_(module.level_embed) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DetaDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DETA_START_DOCSTRING = r""" @@ -1275,15 +1276,8 @@ class DetaDecoder(DetaPreTrainedModel): all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, encoder_hidden_states, encoder_attention_mask, diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index d2b6ea07d7..1c09e3e3d7 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -927,9 +927,10 @@ class DetrPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DETR_START_DOCSTRING = r""" @@ -1253,15 +1254,8 @@ class DetrDecoder(DetrPreTrainedModel): continue if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/dinat/modeling_dinat.py b/src/transformers/models/dinat/modeling_dinat.py index 89c6ed2e2a..eb4d3f2ff2 100644 --- a/src/transformers/models/dinat/modeling_dinat.py +++ b/src/transformers/models/dinat/modeling_dinat.py @@ -660,7 +660,7 @@ class DinatPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: DinatEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: DinatEncoder, gradient_checkpointing_func=None) -> None: pass diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 6e4446fadd..1440b6d615 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -447,17 +447,11 @@ class Dinov2Encoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -516,9 +510,10 @@ class Dinov2PreTrainedModel(PreTrainedModel): std=self.config.initializer_range, ).to(module.cls_token.dtype) - def _set_gradient_checkpointing(self, module: Dinov2Encoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: Dinov2Encoder, gradient_checkpointing_func=None) -> None: if isinstance(module, Dinov2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DINOV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index f26b584697..3768dd6e91 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -358,18 +358,12 @@ class Transformer(nn.Module): all_hidden_states = all_hidden_states + (hidden_state,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_state, attn_mask, head_mask[i], + output_attentions, ) else: layer_outputs = layer_module( @@ -430,9 +424,10 @@ class DistilBertPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Transformer): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DISTILBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index 0d833406e2..76d525717f 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -749,15 +749,8 @@ class DonutSwinEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( @@ -826,9 +819,10 @@ class DonutSwinPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DonutSwinEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWIN_START_DOCSTRING = r""" diff --git a/src/transformers/models/dpr/modeling_dpr.py b/src/transformers/models/dpr/modeling_dpr.py index 944ce142b0..c258343f6c 100644 --- a/src/transformers/models/dpr/modeling_dpr.py +++ b/src/transformers/models/dpr/modeling_dpr.py @@ -164,9 +164,10 @@ class DPRPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class DPREncoder(DPRPreTrainedModel): diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 187a6c3665..2621fa3380 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -528,17 +528,11 @@ class DPTViTEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -818,9 +812,10 @@ class DPTPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DPTViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DPT_START_DOCSTRING = r""" diff --git a/src/transformers/models/efficientnet/modeling_efficientnet.py b/src/transformers/models/efficientnet/modeling_efficientnet.py index 478aeecee0..d1b2c99403 100644 --- a/src/transformers/models/efficientnet/modeling_efficientnet.py +++ b/src/transformers/models/efficientnet/modeling_efficientnet.py @@ -500,9 +500,10 @@ class EfficientNetPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, EfficientNetBlock): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index da3ee8e51d..fde5632c09 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -571,20 +571,15 @@ class ElectraEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -692,9 +687,10 @@ class ElectraPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ElectraEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/encodec/modeling_encodec.py b/src/transformers/models/encodec/modeling_encodec.py index 697fb3c94f..28c20da3d5 100644 --- a/src/transformers/models/encodec/modeling_encodec.py +++ b/src/transformers/models/encodec/modeling_encodec.py @@ -473,9 +473,10 @@ class EncodecPreTrainedModel(PreTrainedModel): elif "bias" in name: nn.init.constant_(param, 0.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (EncodecEncoder, EncodecDecoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ENCODEC_START_DOCSTRING = r""" diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 787db72726..a13fd19a90 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -265,10 +265,10 @@ class EncoderDecoderModel(PreTrainedModel): self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - self.encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_encoder(self): return self.encoder diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index d55155f800..330cb50331 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -506,20 +506,15 @@ class ErnieEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -680,9 +675,10 @@ class ErniePreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ErnieEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/ernie_m/modeling_ernie_m.py b/src/transformers/models/ernie_m/modeling_ernie_m.py index 9c53ddd73c..b26ee0fcaf 100755 --- a/src/transformers/models/ernie_m/modeling_ernie_m.py +++ b/src/transformers/models/ernie_m/modeling_ernie_m.py @@ -429,9 +429,10 @@ class ErnieMPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ErnieMEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ERNIE_M_START_DOCSTRING = r""" diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 7a07495ba7..86bd20a464 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -605,20 +605,15 @@ class EsmEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -710,9 +705,10 @@ class EsmPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, EsmEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ESM_START_DOCSTRING = r""" diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 6eaeed4199..642e60a72f 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1097,9 +1097,10 @@ class FalconPreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->FalconModel - def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False): + def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None): if isinstance(module, FalconModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @staticmethod def _convert_cache_to_standard_format( @@ -1278,21 +1279,16 @@ class FalconModel(FalconPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, alibi, attention_mask, position_ids, head_mask[i], + layer_past, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 8de647c829..1fbf49f9e1 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -663,18 +663,12 @@ class FlavaEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -879,9 +873,10 @@ class FlavaPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: FlavaEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: FlavaEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, FlavaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 4504214776..b84761536b 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -292,14 +292,7 @@ class FNetEncoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(layer_module), hidden_states) + layer_outputs = self.gradient_checkpointing_func(layer_module.__call__, hidden_states) else: layer_outputs = layer_module(hidden_states) @@ -431,9 +424,10 @@ class FNetPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, FNetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index 8d18a8c63f..87ec981696 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -586,15 +586,8 @@ class FocalNetEncoder(nn.Module): for i, stage_module in enumerate(self.stages): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - stage_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(stage_module), + stage_outputs = self.gradient_checkpointing_func( + stage_module.__call__, hidden_states, input_dimensions, ) @@ -659,9 +652,10 @@ class FocalNetPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, FocalNetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None FOCALNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 03312420ca..37f9890ee3 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -70,9 +70,10 @@ class FuyuPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, FuyuForCausalLM): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None FUYU_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 00707e42dd..293b9c789d 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -452,18 +452,13 @@ class GitEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -533,9 +528,10 @@ class GitPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (GitEncoder, GitVisionEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GIT_START_DOCSTRING = r""" @@ -878,18 +874,12 @@ class GitVisionEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 838e7ca299..24826a76bc 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -480,9 +480,10 @@ class GPT2PreTrainedModel(PreTrainedModel): # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPT2Model): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass @@ -877,22 +878,16 @@ class GPT2Model(GPT2PreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( @@ -1623,7 +1618,6 @@ class GPT2ForQuestionAnswering(GPT2PreTrainedModel): # Model parallel self.model_parallel = False self.device_map = None - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index be90f61e45..37c51b40c9 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -405,9 +405,10 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) # Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->GPTBigCode - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTBigCodeModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GPT_BIGCODE_START_DOCSTRING = r""" @@ -650,22 +651,16 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 3ad49554c0..ed1e62bf17 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -384,9 +384,10 @@ class GPTNeoPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTNeoModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GPT_NEO_START_DOCSTRING = r""" @@ -604,20 +605,14 @@ class GPTNeoModel(GPTNeoPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 9391805a77..cf0aa0645a 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -78,9 +78,10 @@ class GPTNeoXPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTNeoXModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class GPTNeoXAttention(nn.Module): @@ -641,20 +642,15 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for layer_past - return module(*inputs, use_cache, None, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, position_ids, head_mask[i], + use_cache, + None, + output_attentions, ) else: outputs = layer( diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 98753edeb5..c1c5527a46 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -66,9 +66,10 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTNeoXJapaneseModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class GPTNeoXJapaneseAttention(nn.Module): diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index acdbb8c492..2910f9535f 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -363,9 +363,10 @@ class GPTJPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTJModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GPTJ_START_DOCSTRING = r""" @@ -669,21 +670,15 @@ class GPTJModel(GPTJPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, position_ids, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py index 24917fcfdb..84d956c9f5 100644 --- a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py +++ b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py @@ -759,9 +759,10 @@ class GPTSanJapanesePreTrainedModel(PreTrainedModel): module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (GPTSanJapaneseAttention,)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right def _shift_right(self, input_ids): diff --git a/src/transformers/models/graphormer/modeling_graphormer.py b/src/transformers/models/graphormer/modeling_graphormer.py index 8247745a3b..68ed6d265e 100755 --- a/src/transformers/models/graphormer/modeling_graphormer.py +++ b/src/transformers/models/graphormer/modeling_graphormer.py @@ -772,9 +772,10 @@ class GraphormerPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GraphormerModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class GraphormerModel(GraphormerPreTrainedModel): diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index 59ff60ed76..a9de671438 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -492,7 +492,6 @@ class GroupViTStage(nn.Module): self.group_token = nn.Parameter(torch.zeros(1, num_group_token, config.hidden_size)) else: self.group_token = None - self.gradient_checkpointing = False self.layers = nn.ModuleList([GroupViTEncoderLayer(config) for _ in range(depth)]) if num_group_token > 0: @@ -805,9 +804,10 @@ class GroupViTPreTrainedModel(PreTrainedModel): nn.init.normal_(module.fc1.weight, std=fc_std) nn.init.normal_(module.fc2.weight, std=in_proj_std) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (GroupViTTextEncoder, GroupViTVisionEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GROUPVIT_START_DOCSTRING = r""" @@ -1031,18 +1031,12 @@ class GroupViTTextEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 1a7bde45ef..732e6be2f8 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -346,15 +346,8 @@ class HubertFeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -731,17 +724,11 @@ class HubertEncoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -821,17 +808,11 @@ class HubertEncoderStableLayerNorm(nn.Module): # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -895,9 +876,10 @@ class HubertPreTrainedModel(PreTrainedModel): if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (HubertEncoder, HubertEncoderStableLayerNorm)): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (HubertFeatureEncoder, HubertEncoder, HubertEncoderStableLayerNorm)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 316f365613..28841903a1 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -40,7 +40,7 @@ from ...utils import ( ) from .configuration_idefics import IdeficsConfig from .perceiver import IdeficsPerceiverResampler -from .vision import IdeficsVisionTransformer +from .vision import IdeficsVisionEncoder, IdeficsVisionTransformer logger = logging.get_logger(__name__) @@ -978,9 +978,10 @@ class IdeficsPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, IdeficsModel): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (IdeficsModel, IdeficsVisionEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LLAMA_INPUTS_DOCSTRING = r""" @@ -1098,7 +1099,6 @@ class IdeficsModel(IdeficsPreTrainedModel): self.norm = IdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1339,7 +1339,7 @@ class IdeficsModel(IdeficsPreTrainedModel): ) use_cache = False - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( vblock, decoder_layer, hidden_states, diff --git a/src/transformers/models/idefics/vision.py b/src/transformers/models/idefics/vision.py index d4966a240d..24dc3e9396 100644 --- a/src/transformers/models/idefics/vision.py +++ b/src/transformers/models/idefics/vision.py @@ -401,18 +401,12 @@ class IdeficsVisionEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 54edcd30fc..a365731ed5 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -525,9 +525,10 @@ class ImageGPTPreTrainedModel(PreTrainedModel): # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ImageGPTModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None IMAGEGPT_START_DOCSTRING = r""" @@ -816,22 +817,16 @@ class ImageGPTModel(ImageGPTPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index e7b35174ca..53518760cc 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -924,9 +924,10 @@ class InformerPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (InformerDecoder, InformerEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None INFORMER_START_DOCSTRING = r""" @@ -1215,21 +1216,15 @@ class InformerEncoder(InformerPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) if conv_layer is not None: - output = torch.utils.checkpoint.checkpoint(conv_layer, layer_outputs[0]) + output = self.gradient_checkpointing_func(conv_layer, layer_outputs[0]) layer_outputs = (output,) + layer_outputs[1:] else: layer_outputs = encoder_layer( @@ -1438,16 +1433,8 @@ class InformerDecoder(InformerPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1455,6 +1442,8 @@ class InformerDecoder(InformerPreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 082900a665..d4cb7a1fa0 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -304,9 +304,14 @@ class InstructBlipPreTrainedModel(PreTrainedModel): elif isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, InstructBlipEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (InstructBlipEncoder, InstructBlipQFormerEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None + + # Enable / disable GC for the language model as well + if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"): + self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func) INSTRUCTBLIP_START_DOCSTRING = r""" @@ -462,17 +467,11 @@ class InstructBlipEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -939,15 +938,8 @@ class InstructBlipQFormerEncoder(nn.Module): "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions, query_length) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 884a279972..ce6d4302bc 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -487,20 +487,15 @@ class LayoutLMEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -638,9 +633,10 @@ class LayoutLMPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LayoutLMEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LAYOUTLM_START_DOCSTRING = r""" diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index ef970edfdc..8f6260fdda 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -439,18 +439,12 @@ class LayoutLMv2Encoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, rel_pos=rel_pos, rel_2d_pos=rel_2d_pos, ) @@ -514,9 +508,10 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LayoutLMv2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def my_convert_sync_batchnorm(module, process_group=None): diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 30ab0a5e86..e387707e52 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -657,19 +657,8 @@ class LayoutLMv3Encoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - # return module(*inputs, past_key_value, output_attentions, rel_pos, rel_2d_pos) - # The above line will cause error: - # RuntimeError: Trying to backward through the graph a second time - # (or directly access saved tensors after they have already been freed). - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index f0c22ed950..61bbd4156b 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1155,9 +1155,10 @@ class LEDPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (LEDDecoder, LEDEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -1876,20 +1877,15 @@ class LEDEncoder(LEDPreTrainedModel): layer_outputs = (None, None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, is_global_attn, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, is_index_masked, is_index_global_attn, + is_global_attn, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -2142,16 +2138,8 @@ class LEDDecoder(LEDPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, @@ -2159,6 +2147,8 @@ class LEDDecoder(LEDPreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/levit/modeling_levit.py b/src/transformers/models/levit/modeling_levit.py index 0accc28391..5acaaeba90 100644 --- a/src/transformers/models/levit/modeling_levit.py +++ b/src/transformers/models/levit/modeling_levit.py @@ -507,9 +507,10 @@ class LevitPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LevitModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LEVIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 46fe2d3e9c..4fd7a85aff 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -514,19 +514,13 @@ class LiltEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layout_inputs, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module( @@ -607,9 +601,10 @@ class LiltPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LiltEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LILT_START_DOCSTRING = r""" diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 541455d86a..279884dc16 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -827,9 +827,10 @@ class LlamaPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LlamaModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LLAMA_INPUTS_DOCSTRING = r""" @@ -1013,16 +1014,14 @@ class LlamaModel(LlamaPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 33bf9a6f92..b4f20b4525 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1304,20 +1304,15 @@ class LongformerEncoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, is_global_attn, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, is_index_masked, is_index_global_attn, + is_global_attn, + output_attentions, ) else: layer_outputs = layer_module( @@ -1439,9 +1434,10 @@ class LongformerPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LongformerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LONGFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index c80d234983..9abbfa2f20 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -775,7 +775,6 @@ class LongT5TransientGlobalAttention(nn.Module): if self.has_relative_attention_bias: self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) self.pruned_heads = set() - self.gradient_checkpointing = False # Relativen attention bias & Layer norm for global attention if self.has_relative_attention_bias: @@ -1340,10 +1339,10 @@ class LongT5PreTrainedModel(PreTrainedModel): mean=0.0, std=factor * ((d_model) ** -0.5) ) - # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._set_gradient_checkpointing with T5->LongT5 - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (LongT5Attention, LongT5Stack)): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (LongT5Attention, LongT5Stack, LongT5LocalAttention)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5 def _shift_right(self, input_ids): @@ -1510,15 +1509,8 @@ class LongT5Stack(LongT5PreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1528,6 +1520,8 @@ class LongT5Stack(LongT5PreTrainedModel): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index 6913ede09d..3b5f4d0bf7 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -788,19 +788,13 @@ class LukeEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, word_hidden_states, entity_hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module( @@ -920,9 +914,10 @@ class LukePreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LukeEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LUKE_START_DOCSTRING = r""" diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 6db8bbb521..4ebe11f3f3 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -552,9 +552,10 @@ class M2M100PreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (M2M100Decoder, M2M100Encoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None M2M_100_START_DOCSTRING = r""" @@ -820,18 +821,12 @@ class M2M100Encoder(M2M100PreTrainedModel): # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1066,16 +1061,8 @@ class M2M100Decoder(M2M100PreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, @@ -1083,6 +1070,8 @@ class M2M100Decoder(M2M100PreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 69de5b2e7d..e2e09b564b 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -500,9 +500,10 @@ class MarianPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MarianDecoder, MarianEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -788,18 +789,12 @@ class MarianEncoder(MarianPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1037,16 +1032,8 @@ class MarianDecoder(MarianPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1054,6 +1041,8 @@ class MarianDecoder(MarianPreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 530c66a0c8..80498efb3c 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -648,20 +648,15 @@ class MarkupLMEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index e839b16f62..86eccc4787 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -1864,20 +1864,14 @@ class Mask2FormerMaskedAttentionDecoder(nn.Module): continue if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, None, None, + output_attentions, ) else: diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 87b91ed64b..7df8b60792 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -848,20 +848,14 @@ class DetrDecoder(nn.Module): continue if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, encoder_attention_mask, None, + output_attentions, ) else: layer_outputs = decoder_layer( @@ -1619,11 +1613,13 @@ class MaskFormerPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MaskFormerPixelLevelModule): - module.encoder.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.encoder.gradient_checkpointing = gradient_checkpointing_func is not None if isinstance(module, DetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index 357ac9d4aa..89c6a0c0e0 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -688,15 +688,11 @@ class MaskFormerSwinEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_hidden_states, output_dimensions, layer_all_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), hidden_states, layer_head_mask + layer_hidden_states, output_dimensions, layer_all_hidden_states = self.gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, ) else: layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module( @@ -752,9 +748,10 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MaskFormerSwinEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel): diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index b53ad8848d..7c4c9bdf95 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -516,9 +516,10 @@ class MBartPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (MBartDecoder, MBartDecoder)): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (MBartDecoder, MBartEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -828,18 +829,12 @@ class MBartEncoder(MBartPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1086,16 +1081,8 @@ class MBartDecoder(MBartPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1103,6 +1090,8 @@ class MBartDecoder(MBartPreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 5d0ad6e341..c23666f10b 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -551,20 +551,15 @@ class MegatronBertEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -728,9 +723,10 @@ class MegatronBertPreTrainedModel(PreTrainedModel): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MegatronBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/mgp_str/modeling_mgp_str.py b/src/transformers/models/mgp_str/modeling_mgp_str.py index 5d1f5bea7b..1257b4df39 100644 --- a/src/transformers/models/mgp_str/modeling_mgp_str.py +++ b/src/transformers/models/mgp_str/modeling_mgp_str.py @@ -333,9 +333,10 @@ class MgpstrPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: MgpstrEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: MgpstrEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, MgpstrEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MGP_STR_START_DOCSTRING = r""" diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 53667b6a82..fbedb20dbc 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -816,9 +816,10 @@ class MistralPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MistralModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MISTRAL_INPUTS_DOCSTRING = r""" @@ -1020,19 +1021,14 @@ class MistralModel(MistralPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, position_ids, + past_key_value, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index c3accb21e0..c664c02a88 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -626,15 +626,8 @@ class MobileViTEncoder(nn.Module): for i, layer_module in enumerate(self.layer): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + hidden_states = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, ) else: @@ -672,9 +665,10 @@ class MobileViTPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MobileViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MOBILEVIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index 5a0e08d734..b88925f41b 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -582,15 +582,8 @@ class MobileViTV2Encoder(nn.Module): for i, layer_module in enumerate(self.layer): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + hidden_states = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, ) else: @@ -629,9 +622,10 @@ class MobileViTV2PreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MobileViTV2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MOBILEVITV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index d760bec985..ede306e71b 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -294,9 +294,10 @@ class MptPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False): + def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None): if isinstance(module, MptModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @staticmethod def _convert_to_mpt_cache( @@ -523,20 +524,14 @@ class MptModel(MptPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, alibi, causal_mask, layer_past, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index d400fea6d2..f6cb65889a 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -766,15 +766,8 @@ class MraEncoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, ) @@ -871,9 +864,10 @@ class MraPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MraEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MRA_START_DOCSTRING = r""" diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 186db94dad..2951ffc889 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -845,9 +845,10 @@ class MT5PreTrainedModel(PreTrainedModel): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MT5Attention, MT5Stack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1073,15 +1074,8 @@ class MT5Stack(MT5PreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1091,6 +1085,8 @@ class MT5Stack(MT5PreTrainedModel): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index bcc6bc82a2..a740ed4707 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -475,9 +475,10 @@ class MusicgenPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MusicgenDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MUSICGEN_START_DOCSTRING = r""" @@ -826,16 +827,8 @@ class MusicgenDecoder(MusicgenPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, @@ -843,6 +836,8 @@ class MusicgenDecoder(MusicgenPreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( @@ -1562,10 +1557,10 @@ class MusicgenForConditionalGeneration(PreTrainedModel): self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - self.text_encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.text_encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_audio_encoder(self): return self.audio_encoder diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 5c1ed05249..122b492878 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -563,9 +563,10 @@ class MvpPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MvpDecoder, MvpEncoder, MvpPrompt)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -949,19 +950,13 @@ class MvpEncoder(MvpPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), (self_attn_prompt[idx] if self.use_prompt else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1227,16 +1222,8 @@ class MvpDecoder(MvpPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1246,6 +1233,8 @@ class MvpDecoder(MvpPreTrainedModel): self_attn_prompt[idx] if self.use_prompt else None, cross_attn_prompt[idx] if self.use_prompt else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/nat/modeling_nat.py b/src/transformers/models/nat/modeling_nat.py index ecc745b558..4f7206a5e8 100644 --- a/src/transformers/models/nat/modeling_nat.py +++ b/src/transformers/models/nat/modeling_nat.py @@ -639,7 +639,7 @@ class NatPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: NatEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: NatEncoder, gradient_checkpointing_func=None) -> None: pass diff --git a/src/transformers/models/nezha/modeling_nezha.py b/src/transformers/models/nezha/modeling_nezha.py index fa31e94f4d..cd43688e3f 100644 --- a/src/transformers/models/nezha/modeling_nezha.py +++ b/src/transformers/models/nezha/modeling_nezha.py @@ -577,20 +577,15 @@ class NezhaEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -752,9 +747,10 @@ class NezhaPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, NezhaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index a88d53a340..cbed1e1b15 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -874,9 +874,10 @@ class NllbMoePreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (NllbMoeDecoder, NllbMoeEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None NLLB_MOE_START_DOCSTRING = r""" @@ -1153,18 +1154,12 @@ class NllbMoeEncoder(NllbMoePreTrainedModel): layer_outputs = (None, None, None) else: if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1426,15 +1421,8 @@ class NllbMoeDecoder(NllbMoePreTrainedModel): "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, combined_attention_mask, encoder_hidden_states, @@ -1442,6 +1430,8 @@ class NllbMoeDecoder(NllbMoePreTrainedModel): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index 51ee73ab72..9b2052eb6c 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -370,17 +370,11 @@ class NystromformerEncoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) @@ -477,9 +471,10 @@ class NystromformerPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, NystromformerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None NYSTROMFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index 5b6220f881..165684542d 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -2616,7 +2616,7 @@ class OneFormerTextTransformer(nn.Module): def forward(self, hidden_states: torch.Tensor): for layer in self.layers: if self.use_checkpoint: - hidden_states = torch.utils.checkpoint.checkpoint(layer, hidden_states) + hidden_states = self.gradient_checkpointing_func(layer, hidden_states) else: hidden_states = layer(hidden_states) return hidden_states diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 8f3f246524..9925e7b4a4 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -411,9 +411,10 @@ class OPTPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (OPTDecoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OPT_INPUTS_DOCSTRING = r""" @@ -691,20 +692,14 @@ class OPTDecoder(OPTPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, causal_attention_mask, head_mask[idx] if head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index 451cc4a691..a1491d15ea 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -584,9 +584,10 @@ class Owlv2PreTrainedModel(PreTrainedModel): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Owlv2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OWLV2_START_DOCSTRING = r""" @@ -764,18 +765,12 @@ class Owlv2Encoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1378,6 +1373,7 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel): def objectness_predictor(self, image_features: torch.FloatTensor) -> torch.FloatTensor: """Predicts the probability that each image feature token is an object. + Args: image_features (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_dim)`)): Features extracted from the image. diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index 66cfb8092d..68037d1395 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -576,9 +576,10 @@ class OwlViTPreTrainedModel(PreTrainedModel): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, OwlViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OWLVIT_START_DOCSTRING = r""" @@ -753,18 +754,12 @@ class OwlViTEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 55856f7b06..058ecd1775 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -500,9 +500,10 @@ class PegasusPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (PegasusDecoder, PegasusEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PEGASUS_START_DOCSTRING = r""" @@ -803,18 +804,12 @@ class PegasusEncoder(PegasusPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1087,16 +1082,8 @@ class PegasusDecoder(PegasusPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1104,6 +1091,8 @@ class PegasusDecoder(PegasusPreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index e87e9c7164..6eaddf642a 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -780,9 +780,10 @@ class PegasusXPreTrainedModel(PreTrainedModel): elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (PegasusXDecoder, PegasusXEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PEGASUS_X_START_DOCSTRING = r""" @@ -1071,18 +1072,12 @@ class PegasusXEncoder(PegasusXPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, global_hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1331,21 +1326,15 @@ class PegasusXDecoder(PegasusXPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index d73cc44844..8043fc8699 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -467,9 +467,10 @@ class PersimmonPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, PersimmonModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PERSIMMON_INPUTS_DOCSTRING = r""" @@ -668,19 +669,13 @@ class PersimmonModel(PersimmonPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, position_ids, + past_key_value, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 58041820c1..cfc2b137c5 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -343,18 +343,12 @@ class Pix2StructVisionEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -563,9 +557,10 @@ class Pix2StructVisionModel(Pix2StructPreTrainedModel): # Initialize weights and apply final processing self.post_init() - def _set_gradient_checkpointing(self, module: Pix2StructVisionEncoder, value: bool = False) -> None: - if isinstance(module, Pix2StructVisionEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module: Pix2StructVisionEncoder, gradient_checkpointing_func=None) -> None: + if isinstance(module, (Pix2StructVisionEncoder, Pix2StructVisionAttention)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def get_input_embeddings(self): return self.embeddings.patch_projection @@ -1320,9 +1315,10 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] supports_gradient_checkpointing = True - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Pix2StructTextAttention, Pix2StructTextModel)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def __init__(self, config): super().__init__(config) @@ -1495,15 +1491,8 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1513,6 +1502,8 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 3a88083923..1e047fd372 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -517,9 +517,10 @@ class PLBartPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (PLBartDecoder, PLBartEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PLBART_START_DOCSTRING = r""" @@ -807,18 +808,12 @@ class PLBartEncoder(PLBartPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1064,16 +1059,8 @@ class PLBartDecoder(PLBartPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1081,6 +1068,8 @@ class PLBartDecoder(PLBartPreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py index 6acc8ec98e..209533e319 100755 --- a/src/transformers/models/poolformer/modeling_poolformer.py +++ b/src/transformers/models/poolformer/modeling_poolformer.py @@ -282,9 +282,10 @@ class PoolFormerPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, PoolFormerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None POOLFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 5a67b8044b..5cf7039e9f 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -739,9 +739,10 @@ class Pop2PianoPreTrainedModel(PreTrainedModel): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Pop2PianoAttention, Pop2PianoStack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -902,15 +903,8 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -920,6 +914,8 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 241a9efea3..e4c28659cb 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -557,9 +557,10 @@ class ProphetNetPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (ProphetNetDecoder, ProphetNetEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1329,18 +1330,12 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel): encoder_hidden_states = encoder_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, extended_attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1569,16 +1564,8 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, extended_attention_mask, encoder_hidden_states, @@ -1590,6 +1577,8 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): predict_relative_position_buckets, position_ids, None, + use_cache, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/pvt/modeling_pvt.py b/src/transformers/models/pvt/modeling_pvt.py index 2dd452ec1d..356b7c14af 100755 --- a/src/transformers/models/pvt/modeling_pvt.py +++ b/src/transformers/models/pvt/modeling_pvt.py @@ -489,9 +489,10 @@ class PvtPreTrainedModel(PreTrainedModel): std=self.config.initializer_range, ) - def _set_gradient_checkpointing(self, module: PvtEncoder, value: bool = False): + def _set_gradient_checkpointing(self, module: PvtEncoder, gradient_checkpointing_func=None): if isinstance(module, PvtEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PVT_START_DOCSTRING = r""" diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py index fead8fc0cf..0a2546a9b6 100755 --- a/src/transformers/models/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -581,20 +581,15 @@ class QDQBertEncoder(nn.Module): "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -757,9 +752,10 @@ class QDQBertPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, QDQBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None QDQBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index aa738d782b..86b37b2156 100644 --- a/src/transformers/models/realm/modeling_realm.py +++ b/src/transformers/models/realm/modeling_realm.py @@ -586,20 +586,15 @@ class RealmEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/regnet/modeling_regnet.py b/src/transformers/models/regnet/modeling_regnet.py index 07ef29fd33..21050f07fd 100644 --- a/src/transformers/models/regnet/modeling_regnet.py +++ b/src/transformers/models/regnet/modeling_regnet.py @@ -293,9 +293,10 @@ class RegNetPreTrainedModel(PreTrainedModel): nn.init.constant_(module.weight, 1) nn.init.constant_(module.bias, 0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RegNetModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None REGNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 235bff89f8..e5e662a9b5 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -543,20 +543,15 @@ class RemBertEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -673,9 +668,10 @@ class RemBertPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RemBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None REMBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index f2d207c218..e6b1d85b2a 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -283,9 +283,10 @@ class ResNetPreTrainedModel(PreTrainedModel): nn.init.constant_(module.weight, 1) nn.init.constant_(module.bias, 0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ResNetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None RESNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 6d4cc991d2..32a19c0883 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -510,20 +510,15 @@ class RobertaEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -612,9 +607,10 @@ class RobertaPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RobertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index da1cd6331b..78ca206845 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -512,20 +512,15 @@ class RobertaPreLayerNormEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -615,9 +610,10 @@ class RobertaPreLayerNormPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RobertaPreLayerNormEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROBERTA_PRELAYERNORM_START_DOCSTRING = r""" diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index a5b1b63050..3a58efa914 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -644,20 +644,15 @@ class RoCBertEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -796,9 +791,10 @@ class RoCBertPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RoCBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROC_BERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index b9c36a305f..3893e27b02 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -578,21 +578,16 @@ class RoFormerEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, sinusoidal_pos, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -715,9 +710,10 @@ class RoFormerPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RoFormerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index db41bd3c95..2752333213 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -466,9 +466,10 @@ class RwkvPreTrainedModel(PreTrainedModel): module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RwkvModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass @@ -676,16 +677,8 @@ class RwkvModel(RwkvPreTrainedModel): all_hidden_states = () if output_hidden_states else None for idx, block in enumerate(self.blocks): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - hidden_states, state, attentions = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), hidden_states, state + hidden_states, state, attentions = self.gradient_checkpointing_func( + block.__call__, hidden_states, state, use_cache, output_attentions ) else: hidden_states, state, attentions = block( diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index abf5544a5b..1bd6fcdc2a 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -1042,15 +1042,8 @@ class SamVisionEncoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, ) else: diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 91ec6a8f9b..ea79c73418 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -892,15 +892,8 @@ class SeamlessM4TConformerEncoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, relative_position_embeddings, @@ -1547,9 +1540,10 @@ class SeamlessM4TPreTrainedModel(PreTrainedModel): k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, a=-k, b=k) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (SeamlessM4TDecoder, SeamlessM4TEncoder, SeamlessM4TSpeechEncoder)): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (SeamlessM4TDecoder, SeamlessM4TEncoder, SeamlessM4TConformerEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask): kernel_size, stride = self.config.adaptor_kernel_size, self.config.adaptor_stride @@ -1856,18 +1850,12 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -2130,16 +2118,8 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -2147,6 +2127,8 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 34f9c84235..36416c168c 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -360,15 +360,8 @@ class SEWFeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -673,17 +666,11 @@ class SEWEncoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -756,9 +743,10 @@ class SEWPreTrainedModel(PreTrainedModel): if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (SEWEncoder, SEWFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 661a8c03b1..39c9641b94 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -453,15 +453,8 @@ class SEWDFeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -1134,20 +1127,14 @@ class SEWDTransformerEncoder(nn.Module): all_hidden_states = all_hidden_states + (output_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - output_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + output_states = self.gradient_checkpointing_func( + layer_module.__call__, next_kv, attention_mask, query_states, relative_pos, rel_embeddings, + output_attentions, ) else: output_states = layer_module( @@ -1322,9 +1309,10 @@ class SEWDPreTrainedModel(PreTrainedModel): attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, SEWDTransformerEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (SEWDEncoder, SEWDFeatureEncoder, SEWDTransformerEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SEWD_START_DOCSTRING = r""" diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index e80c26e269..ec255fab9b 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -249,10 +249,10 @@ class SpeechEncoderDecoderModel(PreTrainedModel): f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - self.encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_encoder(self): return self.encoder diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 31c9b6cfe9..73a02fe66d 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -559,9 +559,10 @@ class Speech2TextPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Speech2TextDecoder, Speech2TextEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ @@ -817,18 +818,12 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1065,16 +1060,8 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1082,6 +1069,8 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index f9b5dec420..acee2b15a4 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -437,9 +437,10 @@ class Speech2Text2PreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Speech2Text2Decoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SPEECH_TO_TEXT_2_START_DOCSTRING = r""" @@ -669,16 +670,8 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 9b8ab3d380..b8fea79664 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -520,15 +520,8 @@ class SpeechT5FeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -1281,9 +1274,10 @@ class SpeechT5PreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (SpeechT5Encoder, SpeechT5Decoder, SpeechT5FeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class SpeechT5Encoder(SpeechT5PreTrainedModel): @@ -1386,19 +1380,13 @@ class SpeechT5Encoder(SpeechT5PreTrainedModel): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), position_bias, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1439,7 +1427,6 @@ class SpeechT5EncoderWithSpeechPrenet(SpeechT5PreTrainedModel): super().__init__(config) self.prenet = SpeechT5SpeechEncoderPrenet(config) self.wrapped_encoder = SpeechT5Encoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1476,7 +1463,6 @@ class SpeechT5EncoderWithTextPrenet(SpeechT5PreTrainedModel): super().__init__(config) self.prenet = SpeechT5TextEncoderPrenet(config) self.wrapped_encoder = SpeechT5Encoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1519,7 +1505,6 @@ class SpeechT5EncoderWithoutPrenet(SpeechT5PreTrainedModel): def __init__(self, config: SpeechT5Config): super().__init__(config) self.wrapped_encoder = SpeechT5Encoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1715,16 +1700,8 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1732,6 +1709,8 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( @@ -1788,7 +1767,6 @@ class SpeechT5DecoderWithSpeechPrenet(SpeechT5PreTrainedModel): super().__init__(config) self.prenet = SpeechT5SpeechDecoderPrenet(config) self.wrapped_decoder = SpeechT5Decoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1836,7 +1814,6 @@ class SpeechT5DecoderWithTextPrenet(SpeechT5PreTrainedModel): super().__init__(config) self.prenet = SpeechT5TextDecoderPrenet(config) self.wrapped_decoder = SpeechT5Decoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1889,7 +1866,6 @@ class SpeechT5DecoderWithoutPrenet(SpeechT5PreTrainedModel): def __init__(self, config: SpeechT5Config): super().__init__(config) self.wrapped_decoder = SpeechT5Decoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index f72ffb1011..1bdf8f3f5f 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -459,20 +459,15 @@ class SplinterEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -544,9 +539,10 @@ class SplinterPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, SplinterEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SPLINTER_START_DOCSTRING = r""" diff --git a/src/transformers/models/swiftformer/modeling_swiftformer.py b/src/transformers/models/swiftformer/modeling_swiftformer.py index ff72f87506..4170ce153b 100644 --- a/src/transformers/models/swiftformer/modeling_swiftformer.py +++ b/src/transformers/models/swiftformer/modeling_swiftformer.py @@ -442,9 +442,10 @@ class SwiftFormerPreTrainedModel(PreTrainedModel): nn.init.constant_(module.bias, 0) nn.init.constant_(module.weight, 1.0) - def _set_gradient_checkpointing(self, module: SwiftFormerEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: SwiftFormerEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, SwiftFormerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWIFTFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 45a7aa718c..c2f15dbbf2 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -825,15 +825,8 @@ class SwinEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( @@ -901,9 +894,10 @@ class SwinPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, SwinEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWIN_START_DOCSTRING = r""" diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py index 02ec39edb0..5d53561442 100644 --- a/src/transformers/models/swin/modeling_tf_swin.py +++ b/src/transformers/models/swin/modeling_tf_swin.py @@ -951,11 +951,6 @@ class TFSwinPreTrainedModel(TFPreTrainedModel): config_class = SwinConfig base_model_prefix = "swin" main_input_name = "pixel_values" - supports_gradient_checkpointing = True - - def _set_gradient_checkpointing(self, module, value=False) -> None: - if isinstance(module, TFSwinEncoder): - module.gradient_checkpointing = value SWIN_START_DOCSTRING = r""" diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index a8a17bdf58..47ce01d169 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -746,15 +746,8 @@ class Swin2SREncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(stage_module), hidden_states, input_dimensions, layer_head_mask + layer_outputs = self.gradient_checkpointing_func( + stage_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = stage_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) @@ -802,9 +795,10 @@ class Swin2SRPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Swin2SREncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWIN2SR_START_DOCSTRING = r""" diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index a4224e16df..6daad938a6 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -906,15 +906,8 @@ class Swinv2Encoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( @@ -983,9 +976,10 @@ class Swinv2PreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Swinv2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWINV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 0a402ea2d6..32d030728d 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -865,9 +865,10 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel): module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (SwitchTransformersAttention, SwitchTransformersStack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1039,15 +1040,8 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1057,6 +1051,8 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 0e7237ea36..c796a9cf24 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -873,9 +873,10 @@ class T5PreTrainedModel(PreTrainedModel): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (T5Attention, T5Stack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1100,15 +1101,8 @@ class T5Stack(T5PreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1118,6 +1112,8 @@ class T5Stack(T5PreTrainedModel): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index b6012700ee..e1da557b00 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -837,9 +837,10 @@ class TableTransformerPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TableTransformerDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TABLE_TRANSFORMER_START_DOCSTRING = r""" @@ -1149,15 +1150,8 @@ class TableTransformerDecoder(TableTransformerPreTrainedModel): continue if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index cdaa4b3e27..de05d77ec9 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -646,20 +646,15 @@ class TapasEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_values, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_values, + output_attentions, ) else: layer_outputs = layer_module( @@ -778,9 +773,10 @@ class TapasPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TapasEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TAPAS_START_DOCSTRING = r""" diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 2caca5bd10..1fa6a963f5 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -663,9 +663,10 @@ class TimeSeriesTransformerPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (TimeSeriesTransformerDecoder, TimeSeriesTransformerEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TIME_SERIES_TRANSFORMER_START_DOCSTRING = r""" @@ -946,18 +947,12 @@ class TimeSeriesTransformerEncoder(TimeSeriesTransformerPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1163,16 +1158,8 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1180,6 +1167,8 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index 676bcf7a5e..044705c35e 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -439,16 +439,10 @@ class TimesformerEncoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, output_attentions) @@ -494,9 +488,10 @@ class TimesformerPreTrainedModel(PreTrainedModel): nn.init.trunc_normal_(module.position_embeddings, std=self.config.initializer_range) module.patch_embeddings.apply(self._init_weights) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TimesformerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TIMESFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index c0541814be..ada8638a03 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -454,9 +454,10 @@ class TrOCRPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TrOCRDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TROCR_START_DOCSTRING = r""" @@ -701,16 +702,8 @@ class TrOCRDecoder(TrOCRPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -718,6 +711,8 @@ class TrOCRDecoder(TrOCRPreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/tvlt/modeling_tvlt.py b/src/transformers/models/tvlt/modeling_tvlt.py index 464c3e76a1..a37265f37c 100644 --- a/src/transformers/models/tvlt/modeling_tvlt.py +++ b/src/transformers/models/tvlt/modeling_tvlt.py @@ -560,18 +560,12 @@ class TvltEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -616,9 +610,10 @@ class TvltPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, TvltEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (TvltEncoder, TvltDecoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TVLT_START_DOCSTRING = r""" @@ -877,17 +872,11 @@ class TvltDecoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, None, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index ffafd15811..a5b58444fe 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -556,9 +556,10 @@ class UMT5PreTrainedModel(PreTrainedModel): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (UMT5Attention, UMT5Stack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -709,15 +710,8 @@ class UMT5Stack(UMT5PreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, encoder_hidden_states, @@ -725,6 +719,8 @@ class UMT5Stack(UMT5PreTrainedModel): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index c475ab7f80..db14d5bca5 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -384,15 +384,8 @@ class UniSpeechFeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -767,17 +760,11 @@ class UniSpeechEncoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -857,17 +844,11 @@ class UniSpeechEncoderStableLayerNorm(nn.Module): # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -1039,9 +1020,10 @@ class UniSpeechPreTrainedModel(PreTrainedModel): attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (UniSpeechEncoder, UniSpeechEncoderStableLayerNorm, UniSpeechFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None UNISPEECH_START_DOCSTRING = r""" diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 3fcc9549bb..8a9a63804b 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -398,15 +398,8 @@ class UniSpeechSatFeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -781,17 +774,11 @@ class UniSpeechSatEncoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -871,17 +858,11 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module): # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -1053,9 +1034,10 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (UniSpeechSatEncoder, UniSpeechSatEncoderStableLayerNorm, UniSpeechSatFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None UNISPEECH_SAT_START_DOCSTRING = r""" diff --git a/src/transformers/models/upernet/modeling_upernet.py b/src/transformers/models/upernet/modeling_upernet.py index b56b508d14..04b8c94e13 100644 --- a/src/transformers/models/upernet/modeling_upernet.py +++ b/src/transformers/models/upernet/modeling_upernet.py @@ -315,9 +315,10 @@ class UperNetPreTrainedModel(PreTrainedModel): if self.auxiliary_head is not None: self.auxiliary_head.init_weights() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BackboneMixin): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None UPERNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 07c32d1492..277280954f 100644 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -434,17 +434,11 @@ class VideoMAEEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -489,9 +483,10 @@ class VideoMAEPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, VideoMAEEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (VideoMAEEncoder, VideoMAEDecoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIDEOMAE_START_DOCSTRING = r""" @@ -726,17 +721,11 @@ class VideoMAEDecoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, None, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index a36d58bd23..482bd08359 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -531,18 +531,12 @@ class ViltEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -591,9 +585,10 @@ class ViltPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ViltEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VILT_START_DOCSTRING = r""" diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index d3e464cbff..84275cc33a 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -225,10 +225,10 @@ class VisionEncoderDecoderModel(PreTrainedModel): f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - self.encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_encoder(self): return self.encoder diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index 81ad106848..425a125a0b 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -418,18 +418,12 @@ class VisualBertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -547,9 +541,10 @@ class VisualBertPreTrainedModel(PreTrainedModel): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, VisualBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 8fdacdddf0..67dbddf876 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -397,17 +397,11 @@ class ViTEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -467,9 +461,10 @@ class ViTPreTrainedModel(PreTrainedModel): std=self.config.initializer_range, ).to(module.cls_token.dtype) - def _set_gradient_checkpointing(self, module: ViTEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ViTEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py index 008f6b3c9d..959522843f 100644 --- a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py @@ -415,17 +415,11 @@ class ViTHybridEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -486,9 +480,10 @@ class ViTHybridPreTrainedModel(PreTrainedModel): std=self.config.initializer_range, ).to(module.cls_token.dtype) - def _set_gradient_checkpointing(self, module: ViTHybridEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ViTHybridEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ViTHybridEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index ef0c7c9f36..e156fdc329 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -536,17 +536,11 @@ class ViTMAEEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -591,9 +585,10 @@ class ViTMAEPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, ViTMAEEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (ViTMAEEncoder, ViTMAEDecoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_MAE_START_DOCSTRING = r""" @@ -793,17 +788,11 @@ class ViTMAEDecoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, None, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index 46639e7d62..b727c331cf 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -387,17 +387,11 @@ class ViTMSNEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -444,9 +438,10 @@ class ViTMSNPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: ViTMSNEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ViTMSNEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ViTMSNEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_MSN_START_DOCSTRING = r""" diff --git a/src/transformers/models/vitdet/modeling_vitdet.py b/src/transformers/models/vitdet/modeling_vitdet.py index e89fdbd7a3..9bb3991fab 100644 --- a/src/transformers/models/vitdet/modeling_vitdet.py +++ b/src/transformers/models/vitdet/modeling_vitdet.py @@ -565,17 +565,11 @@ class VitDetEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -666,9 +660,10 @@ class VitDetPreTrainedModel(PreTrainedModel): module.norm3.weight.data.zero_() module.norm3.bias.data.zero_() - def _set_gradient_checkpointing(self, module: VitDetEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: VitDetEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, VitDetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VITDET_START_DOCSTRING = r""" diff --git a/src/transformers/models/vitmatte/modeling_vitmatte.py b/src/transformers/models/vitmatte/modeling_vitmatte.py index b23bdd21d5..f5025a37e7 100644 --- a/src/transformers/models/vitmatte/modeling_vitmatte.py +++ b/src/transformers/models/vitmatte/modeling_vitmatte.py @@ -86,9 +86,15 @@ class VitMattePreTrainedModel(PreTrainedModel): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BackboneMixin): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None + + for backbone_module in module.modules(): + if hasattr(backbone_module, "gradient_checkpointing"): + backbone_module.gradient_checkpointing_func = gradient_checkpointing_func + backbone_module.gradient_checkpointing = gradient_checkpointing_func is not None class VitMatteBasicConv3x3(nn.Module): diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index 49b9a1f1ae..b621bde35e 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -1167,18 +1167,12 @@ class VitsEncoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, padding_mask, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1296,9 +1290,10 @@ class VitsPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (VitsTextEncoder)): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, VitsEncoder): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VITS_START_DOCSTRING = r""" diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index fd35668572..50cb82fb4e 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -338,17 +338,11 @@ class VivitEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -414,9 +408,10 @@ class VivitPreTrainedModel(PreTrainedModel): elif isinstance(module, nn.Parameter): module.data.normal_(mean=0.0, std=self.config.initializer_range) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, VivitEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIVIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index a6e02a0476..9f48e52962 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -451,15 +451,8 @@ class Wav2Vec2FeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -803,17 +796,11 @@ class Wav2Vec2Encoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -892,17 +879,11 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module): # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -1173,9 +1154,10 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_adapters(self): if self.config.adapter_attn_dim is None: diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index f162c51429..5fba773ee0 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -518,15 +518,8 @@ class Wav2Vec2ConformerFeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -911,18 +904,12 @@ class Wav2Vec2ConformerEncoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, relative_position_embeddings, + output_attentions, ) else: layer_outputs = layer( @@ -1178,9 +1165,10 @@ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel): attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None WAV2VEC2_CONFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 5013837cbd..55b19e4c41 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -354,15 +354,8 @@ class WavLMFeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -713,18 +706,12 @@ class WavLMEncoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, position_bias, + output_attentions, ) else: layer_outputs = layer( @@ -804,18 +791,12 @@ class WavLMEncoderStableLayerNorm(nn.Module): # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, position_bias, + output_attentions, ) else: layer_outputs = layer( @@ -1052,9 +1033,10 @@ class WavLMPreTrainedModel(PreTrainedModel): attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (WavLMEncoder, WavLMEncoderStableLayerNorm, WavLMFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None WAVLM_START_DOCSTRING = r""" diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 8962324471..d6d0302727 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -685,9 +685,10 @@ class WhisperPreTrainedModel(PreTrainedModel): embed_positions = module.embed_positions.weight embed_positions.copy_(sinusoids(*embed_positions.shape)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (WhisperDecoder, WhisperEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ @@ -942,18 +943,12 @@ class WhisperEncoder(WhisperPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, None, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1174,16 +1169,8 @@ class WhisperDecoder(WhisperPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1191,6 +1178,8 @@ class WhisperDecoder(WhisperPreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, # past_key_value + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index da7eddff8d..6c9cc02db9 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -534,9 +534,10 @@ class XCLIPPreTrainedModel(PreTrainedModel): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (XCLIPEncoder, XCLIPVisionEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None X_CLIP_START_DOCSTRING = r""" @@ -703,18 +704,12 @@ class XCLIPEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -950,18 +945,12 @@ class XCLIPVisionEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 0c769dbbb5..1880a78321 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -503,9 +503,10 @@ class XGLMPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, XGLMModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( @@ -674,16 +675,8 @@ class XGLMModel(XGLMPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -691,6 +684,8 @@ class XGLMModel(XGLMPreTrainedModel): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py index cde05cfe8a..9a9f02b74a 100644 --- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -570,9 +570,10 @@ class XLMProphetNetPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (XLMProphetNetDecoder, XLMProphetNetEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1349,18 +1350,12 @@ class XLMProphetNetEncoder(XLMProphetNetPreTrainedModel): encoder_hidden_states = encoder_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, extended_attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1592,16 +1587,8 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, extended_attention_mask, encoder_hidden_states, @@ -1613,6 +1600,8 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel): predict_relative_position_buckets, position_ids, None, + use_cache, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index da454b1e33..da99b2806f 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -511,20 +511,15 @@ class XLMRobertaEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -614,9 +609,10 @@ class XLMRobertaPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, XLMRobertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None XLM_ROBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 26e0361abd..49f7c07517 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -499,20 +499,15 @@ class XLMRobertaXLEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 28fddc2fdb..5f7b42f266 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -573,21 +573,16 @@ class XmodEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, lang_ids, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -680,9 +675,10 @@ class XmodPreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) # Copied from transformers.models.roberta.modeling_roberta.RobertaPreTrainedModel._set_gradient_checkpointing with Roberta->Xmod - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, XmodEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def set_default_language(self, language: str): """ diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index e3cb02ceae..f6cbaecd01 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -492,17 +492,11 @@ class YolosEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -551,9 +545,10 @@ class YolosPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: YolosEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: YolosEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, YolosEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None YOLOS_START_DOCSTRING = r""" diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index 5edd7f8835..8db66d2210 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -561,17 +561,11 @@ class YosoEncoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) @@ -668,9 +662,10 @@ class YosoPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, YosoEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None YOSO_START_DOCSTRING = r""" diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 02fcb7d2f5..0b5af845c9 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -544,19 +544,15 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -679,9 +675,10 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, {{cookiecutter.camelcase_modelname}}Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" @@ -2024,9 +2021,10 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ({{cookiecutter.camelcase_modelname}}Decoder, {{cookiecutter.camelcase_modelname}}Encoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" @@ -2312,18 +2310,12 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -2551,15 +2543,8 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -2567,6 +2552,8 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 34f5bae374..7e1c471bad 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -349,10 +349,24 @@ class ModelTesterMixin: model.gradient_checkpointing_enable() self.assertTrue(model.is_gradient_checkpointing) + # Loop over all modules and check that relevant modules have gradient_checkpointing set to True + for n, m in model.named_modules(): + if hasattr(m, "gradient_checkpointing"): + self.assertTrue( + m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to True" + ) + # check disable works model.gradient_checkpointing_disable() self.assertFalse(model.is_gradient_checkpointing) + # Loop over all modules and check that relevant modules have gradient_checkpointing set to False + for n, m in model.named_modules(): + if hasattr(m, "gradient_checkpointing"): + self.assertFalse( + m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False" + ) + def test_save_load_fast_init_from_base(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() if config.__class__ not in MODEL_MAPPING: @@ -569,6 +583,13 @@ class ModelTesterMixin: loss = model(**inputs).loss loss.backward() + model.gradient_checkpointing_disable() + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + def test_attention_outputs(self): if not self.has_attentions: self.skipTest(reason="Model does not output attentions")