From 9167fadab9bbe7ecb758e5d2e6b70ac4d0781e24 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 22 Apr 2025 11:33:31 +0100 Subject: [PATCH] Introduce GradientCheckpointingLayer (#37223) * GradientCheckpointingLayer * trigger * Move GC layer to a separate file * Update import * Expose and document GC layer * Fix dummy * Apply to llama-based models * Update modulars * Update a few more models for consistency * Update glm4 * Update Janus --- docs/source/en/internal/modeling_utils.md | 4 + src/transformers/__init__.py | 2 + src/transformers/modeling_layers.py | 48 ++++++++++++ src/transformers/models/aria/modeling_aria.py | 39 ++++------ .../models/cohere/modeling_cohere.py | 39 ++++------ .../models/cohere/modular_cohere.py | 3 +- .../models/cohere2/modeling_cohere2.py | 36 +++------ .../models/cohere2/modular_cohere2.py | 33 +++----- .../deepseek_v3/modeling_deepseek_v3.py | 39 ++++------ .../models/diffllama/modeling_diffllama.py | 39 ++++------ src/transformers/models/emu3/modeling_emu3.py | 40 ++++------ .../models/gemma/modeling_gemma.py | 36 +++------ .../models/gemma/modular_gemma.py | 33 +++----- src/transformers/models/glm/modeling_glm.py | 39 ++++------ src/transformers/models/glm4/modeling_glm4.py | 39 ++++------ src/transformers/models/glm4/modular_glm4.py | 4 +- .../models/granite/modeling_granite.py | 39 ++++------ .../models/granite/modular_granite.py | 36 +++------ .../models/helium/modeling_helium.py | 39 ++++------ .../models/janus/modeling_janus.py | 22 ++---- .../models/llama/modeling_llama.py | 39 ++++------ .../models/mistral/modeling_mistral.py | 39 ++++------ .../models/moonshine/modeling_moonshine.py | 75 ++++++------------- .../models/moonshine/modular_moonshine.py | 73 ++++++------------ src/transformers/models/olmo/modeling_olmo.py | 39 ++++------ .../models/olmo2/modeling_olmo2.py | 39 ++++------ src/transformers/models/phi3/modeling_phi3.py | 39 ++++------ .../modeling_phi4_multimodal.py | 68 ++++++----------- .../modular_phi4_multimodal.py | 44 +++-------- .../models/qwen2/modeling_qwen2.py | 39 ++++------ .../models/qwen3/modeling_qwen3.py | 39 ++++------ .../models/siglip/modeling_siglip.py | 22 ++---- .../models/siglip2/modeling_siglip2.py | 22 ++---- .../models/starcoder2/modeling_starcoder2.py | 3 +- src/transformers/utils/dummy_pt_objects.py | 7 ++ 35 files changed, 435 insertions(+), 761 deletions(-) create mode 100644 src/transformers/modeling_layers.py diff --git a/docs/source/en/internal/modeling_utils.md b/docs/source/en/internal/modeling_utils.md index fe6f961da9..1c7d16ad06 100644 --- a/docs/source/en/internal/modeling_utils.md +++ b/docs/source/en/internal/modeling_utils.md @@ -20,6 +20,10 @@ This page lists all the custom layers used by the library, as well as the utilit Most of those are only useful if you are studying the code of the models in the library. +## Layers + +[[autodoc]] GradientCheckpointingLayer + ## Attention Functions [[autodoc]] AttentionInterface diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b7ba86b64f..9d051dd2db 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -438,6 +438,7 @@ else: ] _import_structure["modeling_flash_attention_utils"] = [] + _import_structure["modeling_layers"] = ["GradientCheckpointingLayer"] _import_structure["modeling_outputs"] = [] _import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update"] _import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"] @@ -911,6 +912,7 @@ if TYPE_CHECKING: from .model_debugging_utils import ( model_addition_debugger_context, ) + from .modeling_layers import GradientCheckpointingLayer from .modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from .modeling_utils import AttentionInterface, PreTrainedModel diff --git a/src/transformers/modeling_layers.py b/src/transformers/modeling_layers.py new file mode 100644 index 0000000000..57be2d8e0d --- /dev/null +++ b/src/transformers/modeling_layers.py @@ -0,0 +1,48 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import torch.nn as nn + + +class GradientCheckpointingLayer(nn.Module): + """Base class for layers with gradient checkpointing. + + This class enables gradient checkpointing functionality for a layer. By default, gradient checkpointing is disabled + (`gradient_checkpointing = False`). When `model.set_gradient_checkpointing()` is called, gradient checkpointing is + enabled by setting `gradient_checkpointing = True` and assigning a checkpointing function to `_gradient_checkpointing_func`. + + Important: + + When using gradient checkpointing with `use_reentrant=True`, inputs that require gradients (e.g. hidden states) + must be passed as positional arguments (`*args`) rather than keyword arguments to properly propagate gradients. + + Example: + + ```python + >>> # Correct - hidden_states passed as positional arg + >>> out = self.layer(hidden_states, attention_mask=attention_mask) + + >>> # Incorrect - hidden_states passed as keyword arg + >>> out = self.layer(hidden_states=hidden_states, attention_mask=attention_mask) + ``` + """ + + gradient_checkpointing = False + + def __call__(self, *args, **kwargs): + if self.gradient_checkpointing and self.training: + return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args) + return super().__call__(*args, **kwargs) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 4dc9df7a51..1b9892c94f 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -19,7 +19,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from functools import partial from typing import Callable, List, Optional, Tuple, Union from ...activations import ACT2FN @@ -28,6 +27,7 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -590,7 +590,7 @@ class AriaTextAttention(nn.Module): return attn_output, attn_weights -class AriaTextDecoderLayer(nn.Module): +class AriaTextDecoderLayer(GradientCheckpointingLayer): """ Aria Text Decoder Layer. @@ -940,30 +940,17 @@ class AriaTextModel(AriaTextPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index a250d47809..9d13b30486 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -27,7 +27,6 @@ # This file is based on the LLama model definition file in transformers -from functools import partial from typing import Callable, List, Optional, Tuple, Union import torch @@ -38,6 +37,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -301,7 +301,7 @@ class CohereAttention(nn.Module): return attn_output, attn_weights -class CohereDecoderLayer(nn.Module): +class CohereDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: CohereConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -589,30 +589,17 @@ class CohereModel(CoherePreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/cohere/modular_cohere.py b/src/transformers/models/cohere/modular_cohere.py index c9f2d8ff27..c4acb43494 100644 --- a/src/transformers/models/cohere/modular_cohere.py +++ b/src/transformers/models/cohere/modular_cohere.py @@ -30,6 +30,7 @@ from torch import nn from ...cache_utils import Cache from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -209,7 +210,7 @@ class CohereAttention(LlamaAttention): return attn_output, attn_weights -class CohereDecoderLayer(nn.Module): +class CohereDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: CohereConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 8e7c44fcad..cc790124cc 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -19,7 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import Callable, List, Optional, Tuple, Union import torch @@ -29,6 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache, StaticCache from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -290,7 +290,7 @@ class Cohere2MLP(nn.Module): return down_proj -class Cohere2DecoderLayer(nn.Module): +class Cohere2DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Cohere2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -612,28 +612,16 @@ class Cohere2Model(Cohere2PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - position_embeddings, - causal_mask, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - position_embeddings=position_embeddings, - attention_mask=causal_mask, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=causal_mask, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index 85a8d04a50..e811aabedb 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import Callable, Optional, Tuple import torch @@ -526,28 +525,16 @@ class Cohere2Model(Gemma2Model): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - position_embeddings, - causal_mask, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - position_embeddings=position_embeddings, - attention_mask=causal_mask, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=causal_mask, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 546c7fed3d..511c401f7e 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -5,7 +5,6 @@ # modular_deepseek_v3.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math -from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -18,6 +17,7 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -454,7 +454,7 @@ class DeepseekV3Attention(nn.Module): return attn_output, attn_weights -class DeepseekV3DecoderLayer(nn.Module): +class DeepseekV3DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: DeepseekV3Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -734,30 +734,17 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 1959fac86e..b2d736cb4f 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -22,7 +22,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from functools import partial from typing import Optional, Tuple, Union import torch @@ -38,6 +37,7 @@ from ...modeling_flash_attention_utils import ( _flash_attention_forward, flash_attn_supports_top_left_mask, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -526,7 +526,7 @@ DIFFLLAMA_ATTENTION_CLASSES = { } -class DiffLlamaDecoderLayer(nn.Module): +class DiffLlamaDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: DiffLlamaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -837,30 +837,17 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index f74b0cacb0..375b1beb23 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -21,7 +21,7 @@ # limitations under the License. import math -from functools import cached_property, partial +from functools import cached_property from typing import Callable, List, Optional, Tuple, Union import torch @@ -34,6 +34,7 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -248,7 +249,7 @@ class Emu3Attention(nn.Module): return attn_output, attn_weights -class Emu3DecoderLayer(nn.Module): +class Emu3DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Emu3Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -1422,30 +1423,17 @@ class Emu3TextModel(Emu3PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index df66bc36e4..481fa70eb2 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -29,6 +29,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -282,7 +283,7 @@ class GemmaAttention(nn.Module): return attn_output, attn_weights -class GemmaDecoderLayer(nn.Module): +class GemmaDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: GemmaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -557,29 +558,16 @@ class GemmaModel(GemmaPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index fa6af70ecf..ec16f5bba8 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -434,29 +434,16 @@ class GemmaModel(LlamaModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index a12057cbb2..9a688e0b8e 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -19,7 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -31,6 +30,7 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -300,7 +300,7 @@ class GlmRotaryEmbedding(nn.Module): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -class GlmDecoderLayer(nn.Module): +class GlmDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: GlmConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -572,30 +572,17 @@ class GlmModel(GlmPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index 8eb015ac00..faf188387d 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -19,7 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -31,6 +30,7 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -83,7 +83,7 @@ class Glm4MLP(nn.Module): return self.down_proj(up_states) -class Glm4DecoderLayer(nn.Module): +class Glm4DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Glm4Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -580,30 +580,17 @@ class Glm4Model(Glm4PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/glm4/modular_glm4.py b/src/transformers/models/glm4/modular_glm4.py index 493868f3b3..2a52e5b670 100644 --- a/src/transformers/models/glm4/modular_glm4.py +++ b/src/transformers/models/glm4/modular_glm4.py @@ -15,11 +15,11 @@ # limitations under the License. from typing import Optional, Tuple, Union -import torch.nn as nn import torch.utils.checkpoint from ...cache_utils import Cache from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import CausalLMOutputWithPast from ...processing_utils import Unpack from ...utils import LossKwargs, logging @@ -43,7 +43,7 @@ class Glm4MLP(Phi3MLP): pass -class Glm4DecoderLayer(nn.Module): +class Glm4DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Glm4Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 1db20f8624..4538a323c7 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -19,7 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import Callable, List, Optional, Tuple, Union import torch @@ -31,6 +30,7 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -243,7 +243,7 @@ class GraniteMLP(nn.Module): return down_proj -class GraniteDecoderLayer(nn.Module): +class GraniteDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: GraniteConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -572,30 +572,17 @@ class GraniteModel(GranitePreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index dd88957fdb..1d0ff532d8 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import List, Optional, Tuple, Union import torch @@ -192,30 +191,17 @@ class GraniteModel(LlamaModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 16fcdc3a77..9f7b4b8793 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -20,7 +20,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -31,6 +30,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -285,7 +285,7 @@ class HeliumAttention(nn.Module): return attn_output, attn_weights -class HeliumDecoderLayer(nn.Module): +class HeliumDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None): super().__init__() self.hidden_size = config.hidden_size @@ -557,30 +557,17 @@ class HeliumModel(HeliumPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index 2a0af557f9..4fa538ee7c 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -31,6 +31,7 @@ from ...cache_utils import Cache from ...generation import ClassifierFreeGuidanceLogitsProcessor, GenerationMixin, GenerationMode, LogitsProcessorList from ...generation.utils import GenerateDecoderOnlyOutput from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -429,7 +430,7 @@ class JanusVisionMLP(nn.Module): return hidden_states -class JanusVisionEncoderLayer(nn.Module): +class JanusVisionEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: JanusVisionConfig): super().__init__() self.embed_dim = config.hidden_size @@ -536,19 +537,12 @@ class JanusVisionEncoder(nn.Module): for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index a8bebd2a36..f6d5471ce3 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -17,7 +17,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -29,6 +28,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -290,7 +290,7 @@ class LlamaAttention(nn.Module): return attn_output, attn_weights -class LlamaDecoderLayer(nn.Module): +class LlamaDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -562,30 +562,17 @@ class LlamaModel(LlamaPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 20e8aca622..ebd0fe6c01 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -4,7 +4,6 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_mistral.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -from functools import partial from typing import Callable, List, Optional, Tuple, Union import torch @@ -16,6 +15,7 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -228,7 +228,7 @@ class MistralRMSNorm(nn.Module): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class MistralDecoderLayer(nn.Module): +class MistralDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MistralConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -532,30 +532,17 @@ class MistralModel(MistralPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 97453d6614..6d10364a6e 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -18,7 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import Callable, Optional, Tuple, Union import numpy as np @@ -34,6 +33,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPast, @@ -351,7 +351,7 @@ class MoonshineRotaryEmbedding(nn.Module): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -class MoonshineEncoderLayer(nn.Module): +class MoonshineEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: MoonshineConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -410,7 +410,7 @@ class MoonshineEncoderLayer(nn.Module): return outputs -class MoonshineDecoderLayer(nn.Module): +class MoonshineDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MoonshineConfig, layer_idx: Optional[int] = None): super().__init__() self.hidden_size = config.hidden_size @@ -668,27 +668,14 @@ class MoonshineEncoder(MoonshinePreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - None, - output_attentions, - False, - None, - position_embeddings, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] @@ -912,33 +899,19 @@ class MoonshineDecoder(MoonshinePreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - encoder_hidden_states, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - encoder_attention_mask=encoder_attention_mask, - encoder_hidden_states=encoder_hidden_states, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + encoder_attention_mask=encoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index d76941c8d6..8864c0b41b 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -27,6 +26,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPast, @@ -428,7 +428,7 @@ class MoonshineEncoderLayer(LlamaDecoderLayer): self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False) -class MoonshineDecoderLayer(nn.Module): +class MoonshineDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MoonshineConfig, layer_idx: Optional[int] = None): super().__init__() self.hidden_size = config.hidden_size @@ -686,27 +686,14 @@ class MoonshineEncoder(MoonshinePreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - None, - output_attentions, - False, - None, - position_embeddings, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] @@ -832,33 +819,19 @@ class MoonshineDecoder(LlamaModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - encoder_hidden_states, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - encoder_attention_mask=encoder_attention_mask, - encoder_hidden_states=encoder_hidden_states, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + encoder_attention_mask=encoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index aa7d6e7445..6de200e662 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -4,7 +4,6 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_olmo.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -16,6 +15,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -229,7 +229,7 @@ class OlmoAttention(nn.Module): return attn_output, attn_weights -class OlmoDecoderLayer(nn.Module): +class OlmoDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: OlmoConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -532,30 +532,17 @@ class OlmoModel(OlmoPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 999b2ded05..31e6805cfd 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -4,7 +4,6 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_olmo2.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -16,6 +15,7 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -233,7 +233,7 @@ class Olmo2MLP(nn.Module): return down_proj -class Olmo2DecoderLayer(nn.Module): +class Olmo2DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Olmo2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -538,30 +538,17 @@ class Olmo2Model(Olmo2PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 4955857247..aaf5332c71 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -20,7 +20,6 @@ # limitations under the License. -from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -32,6 +31,7 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -257,7 +257,7 @@ class Phi3RMSNorm(nn.Module): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class Phi3DecoderLayer(nn.Module): +class Phi3DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Phi3Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -587,30 +587,17 @@ class Phi3Model(Phi3PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index 706fb6642b..3b9979edf3 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -36,6 +36,7 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPast, @@ -154,7 +155,7 @@ class Phi4MultimodalVisionAttention(nn.Module): return attn_output, attn_weights -class Phi4MultimodalVisionEncoderLayer(nn.Module): +class Phi4MultimodalVisionEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Phi4MultimodalVisionConfig): super().__init__() self.embed_dim = config.hidden_size @@ -262,19 +263,12 @@ class Phi4MultimodalVisionEncoder(nn.Module): for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1213,14 +1207,7 @@ class Phi4MultimodalAudioModel(Phi4MultimodalAudioPreTrainedModel): attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias for layer in self.encoders: - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - ) - else: - hidden_states = layer(hidden_states, attention_mask) + hidden_states = layer(hidden_states, attention_mask) if unfolded: embed_dim = hidden_states.shape[-1] @@ -1483,7 +1470,7 @@ class Phi4MultimodalAttention(nn.Module): return attn_output, attn_weights -class Phi4MultimodalDecoderLayer(nn.Module): +class Phi4MultimodalDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Phi4MultimodalConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -1885,30 +1872,17 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py index d269b06037..b9d6bc0daf 100644 --- a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py @@ -1265,14 +1265,7 @@ class Phi4MultimodalAudioModel(Phi4MultimodalAudioPreTrainedModel): attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias for layer in self.encoders: - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - ) - else: - hidden_states = layer(hidden_states, attention_mask) + hidden_states = layer(hidden_states, attention_mask) if unfolded: embed_dim = hidden_states.shape[-1] @@ -1655,30 +1648,17 @@ class Phi4MultimodalModel(Phi3Model, nn.Module): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 2d24e228c2..46fb6326ad 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -4,7 +4,6 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_qwen2.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -16,6 +15,7 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -236,7 +236,7 @@ class Qwen2RMSNorm(nn.Module): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class Qwen2DecoderLayer(nn.Module): +class Qwen2DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -545,30 +545,17 @@ class Qwen2Model(Qwen2PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 8416478e5a..2a869fe630 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -19,7 +19,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -31,6 +30,7 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -261,7 +261,7 @@ class Qwen3Attention(nn.Module): return attn_output, attn_weights -class Qwen3DecoderLayer(nn.Module): +class Qwen3DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen3Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -572,30 +572,17 @@ class Qwen3Model(Qwen3PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index 6c488262a5..f77248970c 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -28,6 +28,7 @@ from torch.nn.init import _calculate_fan_in_and_fan_out from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( @@ -465,7 +466,7 @@ class SiglipMLP(nn.Module): return hidden_states -class SiglipEncoderLayer(nn.Module): +class SiglipEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]): super().__init__() self.embed_dim = config.hidden_size @@ -743,19 +744,12 @@ class SiglipEncoder(nn.Module): for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/siglip2/modeling_siglip2.py b/src/transformers/models/siglip2/modeling_siglip2.py index 612e149d54..0999ff6984 100644 --- a/src/transformers/models/siglip2/modeling_siglip2.py +++ b/src/transformers/models/siglip2/modeling_siglip2.py @@ -32,6 +32,7 @@ from torch.nn.init import _calculate_fan_in_and_fan_out from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( @@ -356,7 +357,7 @@ class Siglip2MLP(nn.Module): return hidden_states -class Siglip2EncoderLayer(nn.Module): +class Siglip2EncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]): super().__init__() self.embed_dim = config.hidden_size @@ -462,19 +463,12 @@ class Siglip2Encoder(nn.Module): for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 0299cab66f..84e420607b 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -34,6 +34,7 @@ from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -230,7 +231,7 @@ class Starcoder2Attention(nn.Module): return attn_output, attn_weights -class Starcoder2DecoderLayer(nn.Module): +class Starcoder2DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Starcoder2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index be4e47b9fc..b7291969eb 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -542,6 +542,13 @@ def model_addition_debugger_context(*args, **kwargs): requires_backends(model_addition_debugger_context, ["torch"]) +class GradientCheckpointingLayer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + ROPE_INIT_FUNCTIONS = None