diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index 9eadaa2198..6aaf187b04 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -34,7 +34,11 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) from ...utils import ( ModelOutput, add_code_sample_docstrings, diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index 09ee6eca62..0a7238ef60 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -30,7 +30,12 @@ from ...modeling_outputs import ( BaseModelOutputWithPoolingAndNoAttention, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_start_docstrings, @@ -1100,7 +1105,7 @@ class AlignTextEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 26b3f59280..68a3a28a48 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -30,7 +30,12 @@ from ...modeling_outputs import ( BaseModelOutputWithPoolingAndProjection, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig @@ -651,7 +656,7 @@ class AltRobertaEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -965,7 +970,7 @@ class AltCLIPEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index 0f8c045121..15b8c37935 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -25,7 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_audio_spectrogram_transformer import ASTConfig @@ -343,7 +343,7 @@ class ASTEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 70587add17..981df3ab84 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -34,6 +34,7 @@ from ...modeling_outputs import ( Seq2SeqTSPredictionOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_autoformer import AutoformerConfig @@ -1210,7 +1211,7 @@ class AutoformerEncoder(AutoformerPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1428,7 +1429,7 @@ class AutoformerDecoder(AutoformerPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 5045244902..9000ad3d06 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -35,6 +35,7 @@ from ...modeling_outputs import ( Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_end_docstrings, @@ -849,7 +850,7 @@ class BartEncoder(BartPretrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1105,7 +1106,7 @@ class BartDecoder(BartPretrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index b17721fb2b..b546f14001 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -34,7 +34,7 @@ from ...modeling_outputs import ( SemanticSegmenterOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer, torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -517,7 +517,7 @@ class BeitEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index fb92a0e84c..37f236d4a6 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -40,7 +40,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -598,7 +603,7 @@ class BertEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index f92b7a0633..f20503c594 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -25,7 +25,12 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -408,7 +413,7 @@ class BertEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index e1346a23c9..5e80d0423f 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -37,7 +37,7 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward +from ...pytorch_utils import apply_chunking_to_forward, torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -1622,7 +1622,7 @@ class BigBirdEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 8d7906631d..1ab72f0b49 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -36,6 +36,7 @@ from ...modeling_outputs import ( Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_end_docstrings, @@ -1945,7 +1946,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -2291,7 +2292,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index a9ecb11a61..c29c13547e 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -32,6 +32,7 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -594,7 +595,7 @@ class BioGptModel(BioGptPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 8f2780772c..f96531f51f 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -36,6 +36,7 @@ from ...modeling_outputs import ( Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_end_docstrings, add_start_docstrings, @@ -779,7 +780,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1034,7 +1035,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index ef8d51a2b0..b09dce88e0 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -34,6 +34,7 @@ from ...modeling_outputs import ( Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_end_docstrings, add_start_docstrings, @@ -777,7 +778,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1031,7 +1032,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index f16b89b7a3..93bb26c5b9 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -25,6 +25,7 @@ from torch.nn.functional import normalize from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -620,7 +621,7 @@ class BlipEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 1f269cf852..38866578b6 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -34,6 +34,7 @@ from ...modeling_utils import ( find_pruneable_heads_and_indices, prune_linear_layer, ) +from ...pytorch_utils import torch_custom_checkpointing from ...utils import logging from .configuration_blip import BlipTextConfig @@ -427,7 +428,7 @@ class BlipTextEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 82a879771b..b326ff36c7 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -31,7 +31,12 @@ from ...modeling_outputs import ( BaseModelOutputWithPoolingAndCrossAttentions, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_start_docstrings, @@ -492,7 +497,7 @@ class Blip2Encoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -963,7 +968,7 @@ class Blip2QFormerEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 4f6de49a14..2144c43687 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -33,6 +33,7 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import logging from .configuration_bloom import BloomConfig @@ -775,7 +776,7 @@ class BloomModel(BloomPreTrainedModel): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(block), hidden_states, alibi, diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 4290241fbc..37424e0354 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -32,8 +32,13 @@ from ...modeling_outputs import ( ModelOutput, SequenceClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_bridgetower import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig @@ -810,7 +815,7 @@ class BridgeTowerTextEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index e98840fbc6..25d11d24e1 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -35,7 +35,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -529,7 +534,7 @@ class CamembertEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index a91d42f039..8406a9d1d4 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -36,7 +36,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -800,7 +805,7 @@ class CanineEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 0adf5cfdcb..975857024e 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -31,7 +31,12 @@ from ...modeling_outputs import ( BaseModelOutputWithPoolingAndCrossAttentions, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -914,7 +919,7 @@ class ChineseCLIPTextEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -1023,7 +1028,7 @@ class ChineseCLIPVisionEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, ) diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index c4dbcb03f3..fa83670006 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -30,7 +30,13 @@ from ...modeling_outputs import ( BaseModelOutputWithPoolingAndCrossAttentions, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + meshgrid, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_start_docstrings, @@ -947,7 +953,7 @@ class ClapAudioEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask ) else: @@ -1601,7 +1607,7 @@ class ClapTextEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index ee9d660ef7..6a96715a27 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -25,6 +25,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -644,7 +645,7 @@ class CLIPEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index 85b1196530..fc37277c34 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -26,6 +26,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -654,7 +655,7 @@ class CLIPSegEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 8b1d34f59e..7cee097b2b 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_codegen import CodeGenConfig @@ -549,7 +550,7 @@ class CodeGenModel(CodeGenPreTrainedModel): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 023cb27848..44d9cc9bb5 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -26,6 +26,7 @@ from torch import Tensor, nn from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -1395,7 +1396,7 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index bbdba210c2..49923ba123 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -35,7 +35,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel, SequenceSummary -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_convbert import ConvBertConfig @@ -639,7 +644,7 @@ class ConvBertEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/cvt/modeling_cvt.py b/src/transformers/models/cvt/modeling_cvt.py index 99e3a02feb..8784fff414 100644 --- a/src/transformers/models/cvt/modeling_cvt.py +++ b/src/transformers/models/cvt/modeling_cvt.py @@ -26,7 +26,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward from ...modeling_outputs import ImageClassifierOutputWithNoAttention, ModelOutput -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import logging from .configuration_cvt import CvtConfig diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 168f342acd..72a53c292c 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -35,6 +35,7 @@ from ...modeling_outputs import ( XVectorOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_data2vec_audio import Data2VecAudioConfig @@ -300,7 +301,7 @@ class Data2VecAudioFeatureEncoder(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -600,7 +601,7 @@ class Data2VecAudioEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 206fe1603b..45c182a95c 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -34,7 +34,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -515,7 +520,7 @@ class Data2VecTextEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index 77b4243548..cbef81d2a8 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -33,7 +33,7 @@ from ...modeling_outputs import ( SemanticSegmenterOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer, torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -529,7 +529,7 @@ class Data2VecVisionEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 9a0d43db3a..260e713d5b 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -31,7 +31,7 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import softmax_backward_data +from ...pytorch_utils import softmax_backward_data, torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_deberta import DebertaConfig @@ -464,7 +464,7 @@ class DebertaEncoder(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(layer_module), next_kv, attention_mask, diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 1596ad4ffa..22aef35924 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -32,7 +32,7 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import softmax_backward_data +from ...pytorch_utils import softmax_backward_data, torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_deberta_v2 import DebertaV2Config @@ -508,7 +508,7 @@ class DebertaV2Encoder(nn.Module): return custom_forward - output_states = torch.utils.checkpoint.checkpoint( + output_states = torch_custom_checkpointing( create_custom_forward(layer_module), next_kv, attention_mask, diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 926947b161..64d64191c4 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -27,7 +27,7 @@ from torch.cuda.amp import autocast from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -643,7 +643,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 6469cf7a65..fc195622c2 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -41,7 +41,7 @@ from ...file_utils import ( ) from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import meshgrid +from ...pytorch_utils import meshgrid, torch_custom_checkpointing from ...utils import is_ninja_available, logging from ..auto import AutoBackbone from .configuration_deformable_detr import DeformableDetrConfig @@ -1380,7 +1380,7 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, encoder_hidden_states, diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 8b03835812..4c5491935c 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -33,7 +33,7 @@ from ...modeling_outputs import ( MaskedImageModelingOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -364,7 +364,7 @@ class DeiTEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index af218829d6..67427b4f41 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -36,7 +36,7 @@ from ...file_utils import ( ) from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import meshgrid +from ...pytorch_utils import meshgrid, torch_custom_checkpointing from ...utils import is_torchvision_available, logging, requires_backends from ..auto import AutoBackbone from .configuration_deta import DetaConfig @@ -1272,7 +1272,7 @@ class DetaDecoder(DetaPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, encoder_hidden_states, diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index c92c43e46d..684129663f 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -26,6 +26,7 @@ from torch import Tensor, nn from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -1130,7 +1131,7 @@ class DetrDecoder(DetrPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index 65c48eb81f..07f9fee14e 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -28,7 +28,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -756,7 +756,7 @@ class DonutSwinEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask ) else: diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 187a6c3665..0630a3c48b 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -39,7 +39,7 @@ from ...file_utils import ( ) from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ModelOutput, logging from ..auto import AutoBackbone from .configuration_dpt import DPTConfig @@ -535,7 +535,7 @@ class DPTViTEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index a7ee4ec932..3197e060bd 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -36,7 +36,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel, SequenceSummary -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -576,7 +581,7 @@ class ElectraEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index b8df1b2d50..a5f16a3a86 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -38,7 +38,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -511,7 +516,7 @@ class ErnieEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index e0b26e0f78..27b7bb2d91 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -30,7 +30,8 @@ from ...modeling_outputs import ( SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import logging from .configuration_esm import EsmConfig @@ -610,7 +611,7 @@ class EsmEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 5d49197f8c..0f85ff06f5 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -26,7 +26,8 @@ from torch import nn from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -668,7 +669,7 @@ class FlavaEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 6bc526eeeb..8d8de88c8f 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -43,7 +43,7 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward +from ...pytorch_utils import apply_chunking_to_forward, torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -297,7 +297,7 @@ class FNetEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(layer_module), hidden_states) + layer_outputs = torch_custom_checkpointing(create_custom_forward(layer_module), hidden_states) else: layer_outputs = layer_module(hidden_states) diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index fc327ad0b3..9e8efed443 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -28,6 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -593,7 +594,7 @@ class FocalNetEncoder(nn.Module): return custom_forward - stage_outputs = torch.utils.checkpoint.checkpoint( + stage_outputs = torch_custom_checkpointing( create_custom_forward(stage_module), hidden_states, input_dimensions, diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 23ae6d6496..83bf591fdb 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -34,7 +34,12 @@ from ...modeling_outputs import ( CausalLMOutputWithPast, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_git import GitConfig, GitVisionConfig @@ -457,7 +462,7 @@ class GitEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -883,7 +888,7 @@ class GitVisionEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index b9a8568f00..dab2c61353 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -35,8 +35,12 @@ from ...modeling_outputs import ( SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, SequenceSummary -from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from ...modeling_utils import Conv1D, PreTrainedModel, SequenceSummary +from ...pytorch_utils import ( + find_pruneable_heads_and_indices, + prune_conv1d_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -890,7 +894,7 @@ class GPT2Model(GPT2PreTrainedModel): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 705d07b1da..cf23e1ba08 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -28,6 +28,7 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -661,7 +662,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index b67f4ddbfa..768893cb44 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -34,6 +34,7 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_gpt_neo import GPTNeoConfig @@ -613,7 +614,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 7c3bfd1035..3f7bbcdf64 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -36,6 +36,7 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import logging from .configuration_gpt_neox import GPTNeoXConfig @@ -557,7 +558,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index de12016798..4969bd7fd1 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -31,6 +31,7 @@ from ...modeling_outputs import ( SequenceClassifierOutputWithPast, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -677,7 +678,7 @@ class GPTJModel(GPTJPreTrainedModel): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index c19ebd13b9..e5ee94adbd 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -28,6 +28,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -1037,7 +1038,7 @@ class GroupViTTextEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 70a8c07940..774c4826c9 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...deepspeed import is_deepspeed_zero3_enabled from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -353,7 +354,7 @@ class HubertFeatureEncoder(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -738,7 +739,7 @@ class HubertEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, @@ -828,7 +829,7 @@ class HubertEncoderStableLayerNorm(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 539119fabf..31b911431f 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -32,7 +32,7 @@ from ...modeling_outputs import ( SequenceClassifierOutputWithPast, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer, torch_custom_checkpointing from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_imagegpt import ImageGPTConfig @@ -826,7 +826,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 2bf3f208a9..4774f1d91d 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -30,6 +30,7 @@ from ...modeling_outputs import ( Seq2SeqTSPredictionOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_informer import InformerConfig @@ -1217,14 +1218,14 @@ class InformerEncoder(InformerPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), ) if conv_layer is not None: - output = torch.utils.checkpoint.checkpoint(conv_layer, layer_outputs[0]) + output = torch_custom_checkpointing(conv_layer, layer_outputs[0]) layer_outputs = (output,) + layer_outputs[1:] else: layer_outputs = encoder_layer( @@ -1440,7 +1441,7 @@ class InformerDecoder(InformerPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 410f765094..614bebe121 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -33,7 +33,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_layoutlm import LayoutLMConfig @@ -492,7 +497,7 @@ class LayoutLMEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index 5a6f39ce31..0e0f2c1bd8 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -31,7 +31,7 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward +from ...pytorch_utils import apply_chunking_to_forward, torch_custom_checkpointing from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -455,7 +455,7 @@ class LayoutLMv2Encoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index db6618caae..31fb1f6fb5 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -32,7 +32,7 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward +from ...pytorch_utils import apply_chunking_to_forward, torch_custom_checkpointing from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_layoutlmv3 import LayoutLMv3Config @@ -671,7 +671,7 @@ class LayoutLMv3Encoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index a11659e389..8fa8c00aad 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -35,6 +35,7 @@ from ...modeling_outputs import ( Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -1884,7 +1885,7 @@ class LEDEncoder(LEDPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -2150,7 +2151,7 @@ class LEDDecoder(LEDPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 74454d244e..1953992d05 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -31,7 +31,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_lilt import LiltConfig @@ -519,7 +524,7 @@ class LiltEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layout_inputs, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index c9debdd252..2468f7088b 100755 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_llama import LlamaConfig @@ -568,7 +569,7 @@ class LlamaModel(LlamaPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 665e2cb564..809d889eed 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -25,7 +25,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -1311,7 +1316,7 @@ class LongformerEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 1a49444e8a..d1358a78d8 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -23,7 +23,6 @@ from typing import Any, List, Optional, Tuple, Union import torch from torch import nn from torch.nn import CrossEntropyLoss -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...modeling_outputs import ( @@ -33,7 +32,12 @@ from ...modeling_outputs import ( Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( DUMMY_INPUTS, DUMMY_MASK, @@ -1517,7 +1521,7 @@ class LongT5Stack(LongT5PreTrainedModel): return custom_forward - layer_outputs = checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, extended_attention_mask, diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index ba21d3deb3..0f217909b0 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -26,7 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward +from ...pytorch_utils import apply_chunking_to_forward, torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -795,7 +795,7 @@ class LukeEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), word_hidden_states, entity_hidden_states, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index f8f9e1d3a8..db8e017d17 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -32,6 +32,7 @@ from ...modeling_outputs import ( Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_end_docstrings, @@ -827,7 +828,7 @@ class M2M100Encoder(M2M100PreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1074,7 +1075,7 @@ class M2M100Decoder(M2M100PreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index a75f833fb5..15d58baeda 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -35,6 +35,7 @@ from ...modeling_outputs import ( Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_end_docstrings, add_start_docstrings, @@ -790,7 +791,7 @@ class MarianEncoder(MarianPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1039,7 +1040,7 @@ class MarianDecoder(MarianPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 0c6847b478..0792ff1b72 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -43,6 +43,7 @@ from ...modeling_utils import ( find_pruneable_heads_and_indices, prune_linear_layer, ) +from ...pytorch_utils import torch_custom_checkpointing from ...utils import logging from .configuration_markuplm import MarkupLMConfig @@ -653,7 +654,7 @@ class MarkupLMEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index 4cb2493e58..2fd61a179b 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -36,6 +36,7 @@ from ...file_utils import ( ) from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import logging from .configuration_mask2former import Mask2FormerConfig @@ -1875,7 +1876,7 @@ class Mask2FormerMaskedAttentionDecoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 830f8b23c8..2b91e975ce 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -28,6 +28,7 @@ from ... import AutoBackbone from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -776,7 +777,7 @@ class DetrDecoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index 7016b598e8..e22f466edc 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -28,7 +28,7 @@ from ...activations import ACT2FN from ...file_utils import ModelOutput from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer, torch_custom_checkpointing from ...utils.backbone_utils import BackboneMixin from .configuration_maskformer_swin import MaskFormerSwinConfig @@ -695,7 +695,7 @@ class MaskFormerSwinEncoder(nn.Module): return custom_forward - layer_hidden_states, output_dimensions, layer_all_hidden_states = torch.utils.checkpoint.checkpoint( + layer_hidden_states, output_dimensions, layer_all_hidden_states = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask ) else: diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 67750ab42f..660177708a 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -34,6 +34,7 @@ from ...modeling_outputs import ( Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_end_docstrings, @@ -831,7 +832,7 @@ class MBartEncoder(MBartPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1089,7 +1090,7 @@ class MBartDecoder(MBartPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/mctct/modeling_mctct.py b/src/transformers/models/mctct/modeling_mctct.py index 08e280b3cc..22838d4e28 100755 --- a/src/transformers/models/mctct/modeling_mctct.py +++ b/src/transformers/models/mctct/modeling_mctct.py @@ -33,6 +33,7 @@ from ...modeling_utils import ( find_pruneable_heads_and_indices, prune_linear_layer, ) +from ...pytorch_utils import torch_custom_checkpointing from ...utils import logging from .configuration_mctct import MCTCTConfig @@ -623,7 +624,7 @@ class MCTCTEncoder(MCTCTPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index bba7e7369c..9a24f41e70 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -40,7 +40,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -556,7 +561,7 @@ class MegatronBertEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index 3503e86c9c..e68357e6d3 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -33,7 +33,7 @@ from ...modeling_outputs import ( SemanticSegmenterOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -633,7 +633,7 @@ class MobileViTEncoder(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, ) diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index b8c071a74f..bd2f2bd9cf 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -32,6 +32,7 @@ from ...modeling_outputs import ( SemanticSegmenterOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -589,7 +590,7 @@ class MobileViTV2Encoder(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, ) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index a3cfce8ffc..ce5c81f634 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -23,7 +23,6 @@ from typing import Optional, Tuple, Union import torch from torch import nn from torch.nn import CrossEntropyLoss -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...modeling_outputs import ( @@ -33,7 +32,7 @@ from ...modeling_outputs import ( Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( DUMMY_INPUTS, DUMMY_MASK, @@ -1046,7 +1045,7 @@ class MT5Stack(MT5PreTrainedModel): return custom_forward - layer_outputs = checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, extended_attention_mask, diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 6a44768d8e..4f905a7b51 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -34,6 +34,7 @@ from ...modeling_outputs import ( Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_end_docstrings, @@ -953,7 +954,7 @@ class MvpEncoder(MvpPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1231,7 +1232,7 @@ class MvpDecoder(MvpPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/nezha/modeling_nezha.py b/src/transformers/models/nezha/modeling_nezha.py index 97c5b5a90e..68a78a64fa 100644 --- a/src/transformers/models/nezha/modeling_nezha.py +++ b/src/transformers/models/nezha/modeling_nezha.py @@ -38,7 +38,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -584,7 +589,7 @@ class NezhaEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 06b61c7497..d67032d119 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -22,7 +22,6 @@ from typing import List, Optional, Tuple, Union import torch import torch.nn as nn from torch.nn import CrossEntropyLoss -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...deepspeed import is_deepspeed_zero3_enabled @@ -33,6 +32,7 @@ from ...modeling_outputs import ( Seq2SeqMoEOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_end_docstrings, add_start_docstrings, @@ -1155,7 +1155,7 @@ class NllbMoeEncoder(NllbMoePreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1428,7 +1428,7 @@ class NllbMoeDecoder(NllbMoePreTrainedModel): return custom_forward - layer_outputs = checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index b859b0db1d..6bbd95e709 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -33,7 +33,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_nystromformer import NystromformerConfig @@ -375,7 +380,7 @@ class NystromformerEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index a874611acd..1e2c59a717 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -28,6 +28,7 @@ from ... import AutoBackbone from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -2619,7 +2620,7 @@ class OneFormerTextTransformer(nn.Module): def forward(self, hidden_states: torch.Tensor): for layer in self.layers: if self.use_checkpoint: - hidden_states = torch.utils.checkpoint.checkpoint(layer, hidden_states) + hidden_states = torch_custom_checkpointing(layer, hidden_states) else: hidden_states = layer(hidden_states) return hidden_states diff --git a/src/transformers/models/open_llama/modeling_open_llama.py b/src/transformers/models/open_llama/modeling_open_llama.py index 16ad554dc3..07b19a808d 100644 --- a/src/transformers/models/open_llama/modeling_open_llama.py +++ b/src/transformers/models/open_llama/modeling_open_llama.py @@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_open_llama import OpenLlamaConfig @@ -603,7 +604,7 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index bd64630c62..79b555a5e3 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -29,6 +29,7 @@ from ...modeling_outputs import ( SequenceClassifierOutputWithPast, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -700,7 +701,7 @@ class OPTDecoder(OPTPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, causal_attention_mask, diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index f65a068857..6f43ea603e 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -27,6 +27,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -754,7 +755,7 @@ class OwlViTEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index a2bd3f3812..3af479971e 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -34,6 +34,7 @@ from ...modeling_outputs import ( Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_end_docstrings, add_start_docstrings, @@ -805,7 +806,7 @@ class PegasusEncoder(PegasusPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1089,7 +1090,7 @@ class PegasusDecoder(PegasusPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 8e380a4de5..94fc2d25dd 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -33,6 +33,7 @@ from ...modeling_outputs import ( Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_end_docstrings, add_start_docstrings, @@ -1072,7 +1073,7 @@ class PegasusXEncoder(PegasusXPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, global_hidden_states, @@ -1330,7 +1331,7 @@ class PegasusXDecoder(PegasusXPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 2db104a5a1..0834bbeaaf 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -20,7 +20,6 @@ from typing import Dict, List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...modeling_outputs import ( @@ -31,7 +30,7 @@ from ...modeling_outputs import ( Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...pytorch_utils import ALL_LAYERNORM_LAYERS, torch_custom_checkpointing from ...utils import ( DUMMY_INPUTS, DUMMY_MASK, @@ -350,7 +349,7 @@ class Pix2StructVisionEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -1502,7 +1501,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): return custom_forward - layer_outputs = checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, extended_attention_mask, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 365429360a..23a9f928d1 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -33,6 +33,7 @@ from ...modeling_outputs import ( Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_end_docstrings, @@ -810,7 +811,7 @@ class PLBartEncoder(PLBartPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1067,7 +1068,7 @@ class PLBartDecoder(PLBartPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 9160d5e1eb..007d9aadf2 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -28,6 +28,7 @@ from torch.nn import LayerNorm from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -1336,7 +1337,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, extended_attention_mask, @@ -1577,7 +1578,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, extended_attention_mask, diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py index 47a34e9590..d4371c0efb 100755 --- a/src/transformers/models/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -39,7 +39,7 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -586,7 +586,7 @@ class QDQBertEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index f68fc04105..09d6bb7325 100644 --- a/src/transformers/models/realm/modeling_realm.py +++ b/src/transformers/models/realm/modeling_realm.py @@ -31,7 +31,12 @@ from ...modeling_outputs import ( ModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_realm import RealmConfig @@ -591,7 +596,7 @@ class RealmEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index da4ad96085..06da821d4d 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -36,7 +36,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -548,7 +553,7 @@ class RemBertEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/retribert/modeling_retribert.py b/src/transformers/models/retribert/modeling_retribert.py index 240d9476e7..e1397d39ce 100644 --- a/src/transformers/models/retribert/modeling_retribert.py +++ b/src/transformers/models/retribert/modeling_retribert.py @@ -21,10 +21,10 @@ import math from typing import Optional import torch -import torch.utils.checkpoint as checkpoint from torch import nn from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_start_docstrings, logging from ..bert.modeling_bert import BertModel from .configuration_retribert import RetriBertConfig @@ -141,7 +141,7 @@ class RetriBertModel(RetriBertPreTrainedModel): for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)): b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size] b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size] - pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask) + pooled_output = torch_custom_checkpointing(partial_encode, b_embedding_output, b_attention_mask) pooled_output_list.append(pooled_output) return torch.cat(pooled_output_list, dim=0) diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index b0f1369246..f86fa4aa80 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -35,7 +35,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -515,7 +520,7 @@ class RobertaEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index b1e02e27f1..01276cd071 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -35,7 +35,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -517,7 +522,7 @@ class RobertaPreLayerNormEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 7647c14a9e..63abc9d4aa 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -35,7 +35,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -649,7 +654,7 @@ class RoCBertEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index b966bf4490..586ecbd2da 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -36,7 +36,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel, SequenceSummary -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -585,7 +590,7 @@ class RoFormerEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index c3cbaa9176..0e4177d90b 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -28,6 +28,7 @@ from torch import Tensor, nn from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig @@ -1049,7 +1050,7 @@ class SamVisionEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, ) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index dd854c49f5..75ad1f97df 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...deepspeed import is_deepspeed_zero3_enabled from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_sew import SEWConfig @@ -367,7 +368,7 @@ class SEWFeatureEncoder(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -680,7 +681,7 @@ class SEWEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 7f7c1977d6..b7acb306bb 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -29,7 +29,7 @@ from ...activations import ACT2FN from ...deepspeed import is_deepspeed_zero3_enabled from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import softmax_backward_data +from ...pytorch_utils import softmax_backward_data, torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_sew_d import SEWDConfig @@ -460,7 +460,7 @@ class SEWDFeatureEncoder(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -1141,7 +1141,7 @@ class SEWDTransformerEncoder(nn.Module): return custom_forward - output_states = torch.utils.checkpoint.checkpoint( + output_states = torch_custom_checkpointing( create_custom_forward(layer_module), next_kv, attention_mask, diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index d8a19084eb..3e2024dc69 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -31,6 +31,7 @@ from ...modeling_outputs import ( Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_speech_to_text import Speech2TextConfig @@ -820,7 +821,7 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1068,7 +1069,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index c13b04642d..12e8d4592a 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_start_docstrings, logging, replace_return_docstrings from .configuration_speech_to_text_2 import Speech2Text2Config @@ -677,7 +678,7 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 3e8ce5a23b..5988607f1c 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -35,6 +35,7 @@ from ...modeling_outputs import ( Seq2SeqSpectrogramOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_speecht5 import SpeechT5Config, SpeechT5HifiGanConfig @@ -528,7 +529,7 @@ class SpeechT5FeatureEncoder(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -1394,7 +1395,7 @@ class SpeechT5Encoder(SpeechT5PreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1723,7 +1724,7 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 6e636fb695..88d6a480b7 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -27,7 +27,12 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_splinter import SplinterConfig @@ -464,7 +469,7 @@ class SplinterEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index b324cfdcd9..93144c66a9 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -29,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -832,7 +832,7 @@ class SwinEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask ) else: diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index cd58b70650..6b1b803345 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -27,7 +27,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, ImageSuperResolutionOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -753,7 +753,7 @@ class Swin2SREncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(stage_module), hidden_states, input_dimensions, layer_head_mask ) else: diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index 97b460479d..07dd0a79b7 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -28,7 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -908,7 +908,7 @@ class Swinv2Encoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask ) else: diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 008e23531a..1378ec9a98 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -23,7 +23,6 @@ from typing import Optional, Tuple, Union import torch import torch.nn as nn from torch.nn import CrossEntropyLoss -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...modeling_outputs import ( @@ -33,7 +32,12 @@ from ...modeling_outputs import ( Seq2SeqMoEOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( DUMMY_INPUTS, DUMMY_MASK, @@ -1075,7 +1079,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel): return custom_forward - layer_outputs = checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, extended_attention_mask, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 050309fa9a..4531214e19 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -24,7 +24,6 @@ from typing import Optional, Tuple, Union import torch from torch import nn from torch.nn import CrossEntropyLoss -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...modeling_outputs import ( @@ -34,7 +33,12 @@ from ...modeling_outputs import ( Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( DUMMY_INPUTS, DUMMY_MASK, @@ -1074,7 +1078,7 @@ class T5Stack(T5PreTrainedModel): return custom_forward - layer_outputs = checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, extended_attention_mask, diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index 733ff7b9b4..998f21a286 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -26,6 +26,7 @@ from torch import Tensor, nn from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -1074,7 +1075,7 @@ class TableTransformerDecoder(TableTransformerPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 1621653f3e..4f736a367e 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -34,6 +34,7 @@ from ...pytorch_utils import ( find_pruneable_heads_and_indices, is_torch_greater_or_equal_than_1_12, prune_linear_layer, + torch_custom_checkpointing, ) from ...utils import ( ModelOutput, @@ -653,7 +654,7 @@ class TapasEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 8986ef6729..e3e0b3055d 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -31,6 +31,7 @@ from ...modeling_outputs import ( Seq2SeqTSPredictionOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_time_series_transformer import TimeSeriesTransformerConfig @@ -949,7 +950,7 @@ class TimeSeriesTransformerEncoder(TimeSeriesTransformerPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1166,7 +1167,7 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index 9f886b6ece..5ff5bd7fd1 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_timesformer import TimesformerConfig @@ -446,7 +447,7 @@ class TimesformerEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, ) diff --git a/src/transformers/models/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/trajectory_transformer/modeling_trajectory_transformer.py index e8ecedccb5..1027bd73f3 100644 --- a/src/transformers/models/trajectory_transformer/modeling_trajectory_transformer.py +++ b/src/transformers/models/trajectory_transformer/modeling_trajectory_transformer.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import functional as F from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -556,7 +557,7 @@ class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(block), hidden_states, layer_past, diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 6276c68a42..e8ee10f7de 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_start_docstrings, logging, replace_return_docstrings from .configuration_trocr import TrOCRConfig @@ -709,7 +710,7 @@ class TrOCRDecoder(TrOCRPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/tvlt/modeling_tvlt.py b/src/transformers/models/tvlt/modeling_tvlt.py index 3725c5e772..4b990cdb03 100644 --- a/src/transformers/models/tvlt/modeling_tvlt.py +++ b/src/transformers/models/tvlt/modeling_tvlt.py @@ -29,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -567,7 +567,7 @@ class TvltEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -884,7 +884,7 @@ class TvltDecoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, None, diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index e068fa59e5..5bd1af95c7 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...deepspeed import is_deepspeed_zero3_enabled from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, Wav2Vec2BaseModelOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -391,7 +392,7 @@ class UniSpeechFeatureEncoder(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -774,7 +775,7 @@ class UniSpeechEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, @@ -864,7 +865,7 @@ class UniSpeechEncoderStableLayerNorm(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 2ed8a5d572..f603d2712f 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -36,6 +36,7 @@ from ...modeling_outputs import ( XVectorOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -405,7 +406,7 @@ class UniSpeechSatFeatureEncoder(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -788,7 +789,7 @@ class UniSpeechSatEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, @@ -878,7 +879,7 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index c62d0c4632..5f44a5e4b3 100644 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -30,7 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -441,7 +441,7 @@ class VideoMAEEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -724,7 +724,7 @@ class VideoMAEDecoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, None, diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 6ee1e396a6..5499a26cc7 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -38,6 +38,7 @@ from ...pytorch_utils import ( find_pruneable_heads_and_indices, meshgrid, prune_linear_layer, + torch_custom_checkpointing, ) from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_vilt import ViltConfig @@ -536,7 +537,7 @@ class ViltEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index 0bef6e4af9..a73d6ac720 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -32,7 +32,12 @@ from ...modeling_outputs import ( SequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_start_docstrings, @@ -423,7 +428,7 @@ class VisualBertEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index bfd440caae..28ea5740ca 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -32,7 +32,7 @@ from ...modeling_outputs import ( MaskedImageModelingOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -404,7 +404,7 @@ class ViTEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, diff --git a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py index 051d431946..ba3bbddf56 100644 --- a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py @@ -27,7 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from ..auto import AutoBackbone from .configuration_vit_hybrid import ViTHybridConfig @@ -422,7 +422,7 @@ class ViTHybridEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index ef0c7c9f36..5a9c539fbc 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -29,7 +29,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -543,7 +543,7 @@ class ViTMAEEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -800,7 +800,7 @@ class ViTMAEDecoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, None, diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index 46639e7d62..4f7b412fec 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -27,7 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_vit_msn import ViTMSNConfig @@ -394,7 +394,7 @@ class ViTMSNEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 43ab2408bb..9705a51e48 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -37,6 +37,7 @@ from ...modeling_outputs import ( XVectorOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -458,7 +459,7 @@ class Wav2Vec2FeatureEncoder(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -810,7 +811,7 @@ class Wav2Vec2Encoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, @@ -899,7 +900,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index 3e37a4a504..86c0cbe5e2 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -35,6 +35,7 @@ from ...modeling_outputs import ( XVectorOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -523,7 +524,7 @@ class Wav2Vec2ConformerFeatureEncoder(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -916,7 +917,7 @@ class Wav2Vec2ConformerEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index e4072d9372..35dc46bac1 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -36,6 +36,7 @@ from ...modeling_outputs import ( XVectorOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_wavlm import WavLMConfig @@ -361,7 +362,7 @@ class WavLMFeatureEncoder(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -720,7 +721,7 @@ class WavLMEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, @@ -811,7 +812,7 @@ class WavLMEncoderStableLayerNorm(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 515b886f98..703607ad4a 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -34,6 +34,7 @@ from ...modeling_outputs import ( SequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -853,7 +854,7 @@ class WhisperEncoder(WhisperPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, None, @@ -1085,7 +1086,7 @@ class WhisperDecoder(WhisperPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index 8db4ee0fd1..be6b281890 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -26,6 +26,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -708,7 +709,7 @@ class XCLIPEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -955,7 +956,7 @@ class XCLIPVisionEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 4a72b785a0..61b51d51fc 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_xglm import XGLMConfig @@ -683,7 +684,7 @@ class XGLMModel(XGLMPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py index 2d14bfb6a7..fd90086672 100644 --- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -29,6 +29,7 @@ from torch.nn import LayerNorm from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -1356,7 +1357,7 @@ class XLMProphetNetEncoder(XLMProphetNetPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, extended_attention_mask, @@ -1600,7 +1601,7 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, extended_attention_mask, diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index ae8d51a3f8..e00574239c 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -35,7 +35,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -516,7 +521,7 @@ class XLMRobertaEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index fb86717e1d..71f8de5a72 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -34,7 +34,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -504,7 +509,7 @@ class XLMRobertaXLEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index d99b77fedd..44e50bed3b 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -34,7 +34,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_xmod import XmodConfig @@ -578,7 +583,7 @@ class XmodEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, lang_ids, diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index e3cb02ceae..4b4aa01241 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -27,7 +27,7 @@ from torch import Tensor, nn from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -499,7 +499,7 @@ class YolosEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index 8c2ff9fa4e..1b1e6b13ad 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -34,7 +34,12 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_yoso import YosoConfig @@ -566,7 +571,7 @@ class YosoEncoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index 4723c43035..520eeb8939 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -285,3 +285,18 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]: non-overlapping lifetimes may have the same id. """ return tensor.device, storage_ptr(tensor), storage_size(tensor) + + +def torch_custom_checkpointing(*args): + r""" + A correct usage of `torch.utils.checkpoint.checkpoint` as the default call leads to silent bugs that leads to the + gradients of the last layers not being updated. For more in depth detail of the issue, please have a look at: + https://github.com/huggingface/transformers/pull/24247 + """ + kwargs = {} + if "use_reentrant" in list(inspect.signature(torch.utils.checkpoint.checkpoint).parameters): + kwargs["use_reentrant"] = False + return torch.utils.checkpoint.checkpoint( + *args, + **kwargs, + ) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 4899e19598..c5d141b1f4 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -43,6 +43,7 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel, SequenceSummary +from ...pytorch_utils import torch_custom_checkpointing from ...pytorch_utils import ( apply_chunking_to_forward, find_pruneable_heads_and_indices, @@ -550,7 +551,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -1585,6 +1586,7 @@ from ...modeling_outputs import ( CausalLMOutputWithCrossAttentions ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import logging from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config @@ -2318,7 +2320,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -2557,7 +2559,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/tests/models/align/test_modeling_align.py b/tests/models/align/test_modeling_align.py index 2357c20e21..c8ac69840f 100644 --- a/tests/models/align/test_modeling_align.py +++ b/tests/models/align/test_modeling_align.py @@ -352,6 +352,12 @@ class AlignTextModelTest(ModelTesterMixin, unittest.TestCase): def test_training_gradient_checkpointing(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip(reason="ALIGN does not use inputs_embeds") def test_inputs_embeds(self): pass diff --git a/tests/models/altclip/test_modeling_altclip.py b/tests/models/altclip/test_modeling_altclip.py index 28213de84d..266e0c47b6 100755 --- a/tests/models/altclip/test_modeling_altclip.py +++ b/tests/models/altclip/test_modeling_altclip.py @@ -186,6 +186,12 @@ class AltCLIPVisionModelTest(ModelTesterMixin, unittest.TestCase): def test_training_gradient_checkpointing(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip(reason="AltCLIPVisionModel has no base class and is not available in MODEL_MAPPING") def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/autoformer/test_modeling_autoformer.py b/tests/models/autoformer/test_modeling_autoformer.py index 9f0434689c..ad006d9d07 100644 --- a/tests/models/autoformer/test_modeling_autoformer.py +++ b/tests/models/autoformer/test_modeling_autoformer.py @@ -238,6 +238,12 @@ class AutoformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa def test_resize_tokens_embeddings(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + # # Input is 'static_categorical_features' not 'input_ids' def test_model_main_input_name(self): model_signature = inspect.signature(getattr(AutoformerModel, "forward")) diff --git a/tests/models/beit/test_modeling_beit.py b/tests/models/beit/test_modeling_beit.py index f9aa7339f7..149820023a 100644 --- a/tests/models/beit/test_modeling_beit.py +++ b/tests/models/beit/test_modeling_beit.py @@ -227,6 +227,12 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): def test_multi_gpu_data_parallel_forward(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_model_common_attributes(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/big_bird/test_modeling_big_bird.py b/tests/models/big_bird/test_modeling_big_bird.py index f86c6d0ac7..45bff430bf 100644 --- a/tests/models/big_bird/test_modeling_big_bird.py +++ b/tests/models/big_bird/test_modeling_big_bird.py @@ -609,6 +609,12 @@ class BigBirdModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_change_to_full_attn(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + # overwrite from common in order to skip the check on `attentions` def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None): # `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version, diff --git a/tests/models/blip/test_modeling_blip.py b/tests/models/blip/test_modeling_blip.py index 7d9c6b5ba5..a34efc0264 100644 --- a/tests/models/blip/test_modeling_blip.py +++ b/tests/models/blip/test_modeling_blip.py @@ -789,6 +789,12 @@ class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase): def test_model_common_attributes(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/canine/test_modeling_canine.py b/tests/models/canine/test_modeling_canine.py index d612a02bf4..6e6d7ce383 100644 --- a/tests/models/canine/test_modeling_canine.py +++ b/tests/models/canine/test_modeling_canine.py @@ -499,6 +499,12 @@ class CanineModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): # ViT does not use inputs_embeds pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip("CANINE does not have a get_input_embeddings() method.") def test_model_common_attributes(self): pass diff --git a/tests/models/chinese_clip/test_modeling_chinese_clip.py b/tests/models/chinese_clip/test_modeling_chinese_clip.py index 57f532da86..cf2668f4d8 100644 --- a/tests/models/chinese_clip/test_modeling_chinese_clip.py +++ b/tests/models/chinese_clip/test_modeling_chinese_clip.py @@ -395,6 +395,12 @@ class ChineseCLIPTextModelTest(ModelTesterMixin, unittest.TestCase): def test_training_gradient_checkpointing(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip(reason="ChineseCLIPTextModel has no base class and is not available in MODEL_MAPPING") def test_save_load_fast_init_from_base(self): pass @@ -469,6 +475,12 @@ class ChineseCLIPVisionModelTest(ModelTesterMixin, unittest.TestCase): def test_save_load_fast_init_to_base(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_model_from_pretrained(self): for model_name in CHINESE_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py index d16241ab2f..82592d8452 100644 --- a/tests/models/clip/test_modeling_clip.py +++ b/tests/models/clip/test_modeling_clip.py @@ -227,6 +227,12 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase): def test_training_gradient_checkpointing(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip(reason="CLIPVisionModel has no base class and is not available in MODEL_MAPPING") def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/clipseg/test_modeling_clipseg.py b/tests/models/clipseg/test_modeling_clipseg.py index b54861d8d8..387a2e1c8f 100644 --- a/tests/models/clipseg/test_modeling_clipseg.py +++ b/tests/models/clipseg/test_modeling_clipseg.py @@ -202,6 +202,12 @@ class CLIPSegVisionModelTest(ModelTesterMixin, unittest.TestCase): def test_training_gradient_checkpointing(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip(reason="CLIPSegVisionModel has no base class and is not available in MODEL_MAPPING") def test_save_load_fast_init_from_base(self): pass @@ -448,6 +454,12 @@ class CLIPSegModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) def test_hidden_states_output(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip(reason="Inputs_embeds is tested in individual model tests") def test_inputs_embeds(self): pass diff --git a/tests/models/data2vec/test_modeling_data2vec_vision.py b/tests/models/data2vec/test_modeling_data2vec_vision.py index b4c391fea1..90786de249 100644 --- a/tests/models/data2vec/test_modeling_data2vec_vision.py +++ b/tests/models/data2vec/test_modeling_data2vec_vision.py @@ -310,6 +310,12 @@ class Data2VecVisionModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_model_from_pretrained(self): for model_name in DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/models/dpt/test_modeling_dpt.py b/tests/models/dpt/test_modeling_dpt.py index 76790ee795..5889653991 100644 --- a/tests/models/dpt/test_modeling_dpt.py +++ b/tests/models/dpt/test_modeling_dpt.py @@ -182,6 +182,12 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): def test_inputs_embeds(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_model_common_attributes(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/dpt/test_modeling_dpt_hybrid.py b/tests/models/dpt/test_modeling_dpt_hybrid.py index 6d4a75c80d..04ba8c0289 100644 --- a/tests/models/dpt/test_modeling_dpt_hybrid.py +++ b/tests/models/dpt/test_modeling_dpt_hybrid.py @@ -196,6 +196,12 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): def test_inputs_embeds(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_model_common_attributes(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/flava/test_modeling_flava.py b/tests/models/flava/test_modeling_flava.py index 2544b7ee93..b6d71f33a6 100644 --- a/tests/models/flava/test_modeling_flava.py +++ b/tests/models/flava/test_modeling_flava.py @@ -185,6 +185,12 @@ class FlavaImageModelTest(ModelTesterMixin, unittest.TestCase): # FLAVA does not use inputs_embeds pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_model_common_attributes(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -462,6 +468,12 @@ class FlavaTextModelTest(ModelTesterMixin, unittest.TestCase): # FLAVA does not use inputs_embeds pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + # skip this test as FlavaTextModel has no base class and is # not available in MODEL_MAPPING def test_save_load_fast_init_from_base(self): @@ -624,6 +636,12 @@ class FlavaMultimodalModelTest(ModelTesterMixin, unittest.TestCase): def test_save_load_fast_init_to_base(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_model_from_pretrained(self): for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: @@ -731,6 +749,12 @@ class FlavaImageCodebookTest(ModelTesterMixin, unittest.TestCase): def test_save_load_fast_init_to_base(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_model_from_pretrained(self): for model_name in FLAVA_CODEBOOK_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: @@ -1156,6 +1180,12 @@ class FlavaForPreTrainingTest(FlavaModelTest): class_for_tester = FlavaForPreTrainingTester test_torchscript = False + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/models/fnet/test_modeling_fnet.py b/tests/models/fnet/test_modeling_fnet.py index e7e592d5b6..9682184273 100644 --- a/tests/models/fnet/test_modeling_fnet.py +++ b/tests/models/fnet/test_modeling_fnet.py @@ -444,6 +444,12 @@ class FNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_model_from_pretrained(self): for model_name in FNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py index 65542b4954..620bb30b26 100644 --- a/tests/models/gpt2/test_modeling_gpt2.py +++ b/tests/models/gpt2/test_modeling_gpt2.py @@ -562,6 +562,12 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_gpt2_weight_initialization(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_batch_generation(self): model = GPT2LMHeadModel.from_pretrained("gpt2") diff --git a/tests/models/graphormer/test_modeling_graphormer.py b/tests/models/graphormer/test_modeling_graphormer.py index e874ebf0f4..f1c63729e0 100644 --- a/tests/models/graphormer/test_modeling_graphormer.py +++ b/tests/models/graphormer/test_modeling_graphormer.py @@ -356,6 +356,12 @@ class GraphormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa def test_feed_forward_chunking(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip(reason="Graphormer does not share input and output embeddings") def test_model_common_attributes(self): pass diff --git a/tests/models/imagegpt/test_modeling_imagegpt.py b/tests/models/imagegpt/test_modeling_imagegpt.py index 27d83f3eb8..1f4ea02f8d 100644 --- a/tests/models/imagegpt/test_modeling_imagegpt.py +++ b/tests/models/imagegpt/test_modeling_imagegpt.py @@ -304,6 +304,12 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM def test_config(self): self.config_tester.run_common_tests() + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_imagegpt_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_imagegpt_model(*config_and_inputs) diff --git a/tests/models/informer/test_modeling_informer.py b/tests/models/informer/test_modeling_informer.py index f3c8539d84..2202d62242 100644 --- a/tests/models/informer/test_modeling_informer.py +++ b/tests/models/informer/test_modeling_informer.py @@ -216,6 +216,12 @@ class InformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_hidden_states_output(self): def check_hidden_states_output(inputs_dict, config, model_class): model = model_class(config) diff --git a/tests/models/layoutlm/test_modeling_layoutlm.py b/tests/models/layoutlm/test_modeling_layoutlm.py index 0535fbf4e1..b88d0c4b50 100644 --- a/tests/models/layoutlm/test_modeling_layoutlm.py +++ b/tests/models/layoutlm/test_modeling_layoutlm.py @@ -279,6 +279,12 @@ class LayoutLMModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_question_answering(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def prepare_layoutlm_batch_inputs(): # Here we prepare a batch of 2 sequences to test a LayoutLM forward pass on: diff --git a/tests/models/lilt/test_modeling_lilt.py b/tests/models/lilt/test_modeling_lilt.py index 1bb92300c3..4032504b8b 100644 --- a/tests/models/lilt/test_modeling_lilt.py +++ b/tests/models/lilt/test_modeling_lilt.py @@ -275,6 +275,12 @@ class LiltModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_question_answering(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_model_from_pretrained(self): for model_name in LILT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/models/luke/test_modeling_luke.py b/tests/models/luke/test_modeling_luke.py index 35bdb6b6d5..4e1ef3d173 100644 --- a/tests/models/luke/test_modeling_luke.py +++ b/tests/models/luke/test_modeling_luke.py @@ -697,6 +697,12 @@ class LukeModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_model_from_pretrained(self): for model_name in LUKE_PRETRAINED_MODEL_ARCHIVE_LIST: diff --git a/tests/models/marian/test_modeling_marian.py b/tests/models/marian/test_modeling_marian.py index 6cbcd55d3f..933383d292 100644 --- a/tests/models/marian/test_modeling_marian.py +++ b/tests/models/marian/test_modeling_marian.py @@ -263,6 +263,12 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix def test_config(self): self.config_tester.run_common_tests() + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_save_load_strict(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs() for model_class in self.all_model_classes: diff --git a/tests/models/owlvit/test_modeling_owlvit.py b/tests/models/owlvit/test_modeling_owlvit.py index acf078ffe8..83fb86ba0e 100644 --- a/tests/models/owlvit/test_modeling_owlvit.py +++ b/tests/models/owlvit/test_modeling_owlvit.py @@ -155,6 +155,12 @@ class OwlViTVisionModelTest(ModelTesterMixin, unittest.TestCase): def test_inputs_embeds(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_model_common_attributes(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -633,6 +639,12 @@ class OwlViTForObjectDetectionTest(ModelTesterMixin, unittest.TestCase): def test_training_gradient_checkpointing(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def _create_and_check_torchscript(self, config, inputs_dict): if not self.test_torchscript: return diff --git a/tests/models/pegasus/test_modeling_pegasus.py b/tests/models/pegasus/test_modeling_pegasus.py index bde7477f94..1f409d1b00 100644 --- a/tests/models/pegasus/test_modeling_pegasus.py +++ b/tests/models/pegasus/test_modeling_pegasus.py @@ -280,6 +280,12 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_generate_fp16(self): config, input_dict = self.model_tester.prepare_config_and_inputs() input_ids = input_dict["input_ids"] diff --git a/tests/models/pix2struct/test_modeling_pix2struct.py b/tests/models/pix2struct/test_modeling_pix2struct.py index 8ec023676d..1eba4cb10c 100644 --- a/tests/models/pix2struct/test_modeling_pix2struct.py +++ b/tests/models/pix2struct/test_modeling_pix2struct.py @@ -332,6 +332,12 @@ class Pix2StructTextModelTest(ModelTesterMixin, unittest.TestCase): def test_training(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip(reason="Training is tested directly on `Pix2StructTextImageModelTest`") def test_training_gradient_checkpointing(self): pass diff --git a/tests/models/regnet/test_modeling_regnet.py b/tests/models/regnet/test_modeling_regnet.py index e7c33699fd..9b26084528 100644 --- a/tests/models/regnet/test_modeling_regnet.py +++ b/tests/models/regnet/test_modeling_regnet.py @@ -161,6 +161,12 @@ class RegNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): def test_model_common_attributes(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/roformer/test_modeling_roformer.py b/tests/models/roformer/test_modeling_roformer.py index 357e126a04..6d54b7c128 100644 --- a/tests/models/roformer/test_modeling_roformer.py +++ b/tests/models/roformer/test_modeling_roformer.py @@ -452,6 +452,12 @@ class RoFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_model_as_decoder_with_default_input_mask(self): # This regression test was failing with PyTorch < 1.3 ( diff --git a/tests/models/sam/test_modeling_sam.py b/tests/models/sam/test_modeling_sam.py index a0f39a4013..8a47721386 100644 --- a/tests/models/sam/test_modeling_sam.py +++ b/tests/models/sam/test_modeling_sam.py @@ -421,6 +421,12 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): def test_training_gradient_checkpointing(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip(reason="SamModel has no base class and is not available in MODEL_MAPPING") def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/speech_to_text/test_modeling_speech_to_text.py b/tests/models/speech_to_text/test_modeling_speech_to_text.py index 16ad704fd5..1524ce24d2 100644 --- a/tests/models/speech_to_text/test_modeling_speech_to_text.py +++ b/tests/models/speech_to_text/test_modeling_speech_to_text.py @@ -324,6 +324,12 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest def test_training_gradient_checkpointing(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_generate_fp16(self): config, input_dict = self.model_tester.prepare_config_and_inputs() input_features = input_dict["input_features"] diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index f8730d8993..4ff4554fbc 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -613,6 +613,12 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_beam_sample_generate_dict_output(self): r""" diff --git a/tests/models/time_series_transformer/test_modeling_time_series_transformer.py b/tests/models/time_series_transformer/test_modeling_time_series_transformer.py index 42319a1dd0..44962267fe 100644 --- a/tests/models/time_series_transformer/test_modeling_time_series_transformer.py +++ b/tests/models/time_series_transformer/test_modeling_time_series_transformer.py @@ -200,6 +200,12 @@ class TimeSeriesTransformerModelTest(ModelTesterMixin, PipelineTesterMixin, unit def test_config(self): self.config_tester.run_common_tests() + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_save_load_strict(self): config, _ = self.model_tester.prepare_config_and_inputs() for model_class in self.all_model_classes: diff --git a/tests/models/van/test_modeling_van.py b/tests/models/van/test_modeling_van.py index 49df30a828..7ec941dbc8 100644 --- a/tests/models/van/test_modeling_van.py +++ b/tests/models/van/test_modeling_van.py @@ -243,6 +243,12 @@ class VanModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): model = VanModel.from_pretrained(model_name) self.assertIsNotNone(model) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/models/vilt/test_modeling_vilt.py b/tests/models/vilt/test_modeling_vilt.py index 772091d5b9..17447acf68 100644 --- a/tests/models/vilt/test_modeling_vilt.py +++ b/tests/models/vilt/test_modeling_vilt.py @@ -340,6 +340,12 @@ class ViltModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): def test_model_outputs_equivalence(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.return_dict = True diff --git a/tests/models/visual_bert/test_modeling_visual_bert.py b/tests/models/visual_bert/test_modeling_visual_bert.py index cf48fd7ffb..5dae4ebe1f 100644 --- a/tests/models/visual_bert/test_modeling_visual_bert.py +++ b/tests/models/visual_bert/test_modeling_visual_bert.py @@ -549,6 +549,12 @@ class VisualBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa config_and_inputs = self.model_tester.prepare_config_and_inputs_for_flickr() self.model_tester.create_and_check_for_flickr(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_model_from_pretrained(self): for model_name in VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/models/vit_mae/test_modeling_vit_mae.py b/tests/models/vit_mae/test_modeling_vit_mae.py index c58e2e9480..77c36bef8b 100644 --- a/tests/models/vit_mae/test_modeling_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_vit_mae.py @@ -208,6 +208,12 @@ class ViTMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_pretraining(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + # overwrite from common since ViTMAEForPretraining has random masking, we need to fix the noise # to generate masks during test def check_pt_tf_models(self, tf_model, pt_model, pt_inputs_dict): diff --git a/tests/models/x_clip/test_modeling_x_clip.py b/tests/models/x_clip/test_modeling_x_clip.py index 2efece44ca..7fd65d871d 100644 --- a/tests/models/x_clip/test_modeling_x_clip.py +++ b/tests/models/x_clip/test_modeling_x_clip.py @@ -202,6 +202,12 @@ class XCLIPVisionModelTest(ModelTesterMixin, unittest.TestCase): def test_save_load_fast_init_to_base(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_model_from_pretrained(self): for model_name in XCLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 07a8b16bfe..7c02141f05 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import collections import copy import gc @@ -549,6 +548,41 @@ class ModelTesterMixin: loss = model(**inputs).loss loss.backward() + @slow + @require_torch_gpu + def test_training_gradient_checkpointing_autocast(self): + if not self.model_tester.is_training: + return + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.use_cache = False + config.return_dict = True + + if ( + model_class.__name__ + in [*get_values(MODEL_MAPPING_NAMES), *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)] + or not model_class.supports_gradient_checkpointing + ): + continue + model = model_class(config) + model.to(torch_device) + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + model.gradient_checkpointing_enable() + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + with torch.cuda.amp.autocast(True, dtype=torch.float16): + output = model(**inputs)[0] + loss = output.mean() + + loss.backward() + optimizer.step() + + for n, param in model.named_parameters(): + self.assertTrue(param.grad is not None, f"None gradient in param {n}") + def test_attention_outputs(self): if not self.has_attentions: self.skipTest(reason="Model does not output attentions")