From 2b4734bd4907d54a14f992c42d079af8dfffe6b0 Mon Sep 17 00:00:00 2001 From: efsotr <104755879+efsotr@users.noreply.github.com> Date: Mon, 31 Mar 2025 16:53:02 +0800 Subject: [PATCH] Support passing flash_attn_kwargs when gradient_checkpointing is enabled (#37037) * support passing flash_attn_kwargs when gradient_checkpointing is enabled * make modeling_deepspeek_v3.py consistent with modular_deepseek_v3.py --- examples/modular-transformers/modeling_dummy.py | 3 ++- examples/modular-transformers/modeling_multimodal1.py | 3 ++- src/transformers/models/aria/modeling_aria.py | 3 ++- src/transformers/models/cohere/modeling_cohere.py | 3 ++- src/transformers/models/cohere2/modeling_cohere2.py | 3 ++- src/transformers/models/cohere2/modular_cohere2.py | 3 ++- src/transformers/models/deepseek_v3/modeling_deepseek_v3.py | 3 ++- src/transformers/models/diffllama/modeling_diffllama.py | 3 ++- src/transformers/models/emu3/modeling_emu3.py | 4 ++-- src/transformers/models/gemma2/modeling_gemma2.py | 3 ++- src/transformers/models/gemma2/modular_gemma2.py | 3 ++- src/transformers/models/gemma3/modeling_gemma3.py | 3 ++- src/transformers/models/gemma3/modular_gemma3.py | 3 ++- src/transformers/models/glm/modeling_glm.py | 3 ++- src/transformers/models/granite/modeling_granite.py | 3 ++- src/transformers/models/granite/modular_granite.py | 3 ++- src/transformers/models/helium/modeling_helium.py | 3 ++- src/transformers/models/llama/modeling_llama.py | 3 ++- src/transformers/models/mistral/modeling_mistral.py | 3 ++- src/transformers/models/mixtral/modeling_mixtral.py | 3 ++- src/transformers/models/mixtral/modular_mixtral.py | 3 ++- src/transformers/models/moonshine/modeling_moonshine.py | 3 ++- src/transformers/models/moonshine/modular_moonshine.py | 3 ++- src/transformers/models/olmo/modeling_olmo.py | 3 ++- src/transformers/models/olmo2/modeling_olmo2.py | 3 ++- src/transformers/models/phi/modeling_phi.py | 3 ++- src/transformers/models/phi/modular_phi.py | 3 ++- src/transformers/models/phi3/modeling_phi3.py | 3 ++- src/transformers/models/qwen2/modeling_qwen2.py | 3 ++- 29 files changed, 58 insertions(+), 30 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 7ade13e977..98a72a3e65 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -4,6 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_dummy.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -544,7 +545,7 @@ class DummyModel(DummyPreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index ee649f9286..91d226d12b 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -4,6 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_multimodal1.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -544,7 +545,7 @@ class Multimodal1TextModel(Multimodal1TextPreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 5e20264517..d87b8ec0c5 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -19,6 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass +from functools import partial from typing import Callable, List, Optional, Tuple, Union from ...activations import ACT2FN @@ -963,7 +964,7 @@ class AriaTextModel(AriaTextPreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 24fae66f05..60adcf89af 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -27,6 +27,7 @@ # This file is based on the LLama model definition file in transformers +from functools import partial from typing import Callable, List, Optional, Tuple, Union import torch @@ -613,7 +614,7 @@ class CohereModel(CoherePreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 0f21f7045b..be51a992a8 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -19,6 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Callable, List, Optional, Tuple, Union import torch @@ -634,7 +635,7 @@ class Cohere2Model(Cohere2PreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, position_embeddings, causal_mask, diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index 154330b1c9..ce092545f1 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -533,7 +534,7 @@ class Cohere2Model(Gemma2Model): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, position_embeddings, causal_mask, diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index cab1e41cd7..24870d2f69 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -5,6 +5,7 @@ # modular_deepseek_v3.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math +from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -759,7 +760,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 8d13b17872..c86fffad7a 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +from functools import partial from typing import Optional, Tuple, Union import torch @@ -852,7 +853,7 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 82dfc23daf..43996b4132 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -21,7 +21,7 @@ # limitations under the License. import math -from functools import cached_property +from functools import cached_property, partial from typing import Callable, List, Optional, Tuple, Union import torch @@ -1439,7 +1439,7 @@ class Emu3TextModel(Emu3PreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 0c6b8188fb..6b23f26208 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -19,6 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -645,7 +646,7 @@ class Gemma2Model(Gemma2PreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, position_embeddings, causal_mask, diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index ab567c61d0..06f09fab10 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -491,7 +492,7 @@ class Gemma2Model(GemmaModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, position_embeddings, causal_mask, diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 92d2d36caa..f5700f060d 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -22,6 +22,7 @@ import copy from collections.abc import Callable from dataclasses import dataclass +from functools import partial from typing import List, Optional, Tuple, Union import torch @@ -732,7 +733,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, position_embeddings_global, position_embeddings_local, diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index e9baaf1c52..f869a06530 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -16,6 +16,7 @@ import copy from collections.abc import Callable from dataclasses import dataclass +from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -662,7 +663,7 @@ class Gemma3TextModel(Gemma2Model): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, position_embeddings_global, position_embeddings_local, diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 28156d404c..716c97de3f 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -19,6 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -594,7 +595,7 @@ class GlmModel(GlmPreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index d564e08580..f25cbe0dac 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -19,6 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Callable, List, Optional, Tuple, Union import torch @@ -593,7 +594,7 @@ class GraniteModel(GranitePreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index f23ae4a673..3781ea47ad 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import List, Optional, Tuple, Union import torch @@ -185,7 +186,7 @@ class GraniteModel(LlamaModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 3164986642..be55e4ebf9 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -20,6 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -581,7 +582,7 @@ class HeliumModel(HeliumPreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 513e65204f..78cf7a930a 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -17,6 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -583,7 +584,7 @@ class LlamaModel(LlamaPreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index bcb294712c..c7b9a4523d 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -4,6 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_mistral.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from functools import partial from typing import Callable, List, Optional, Tuple, Union import torch @@ -548,7 +549,7 @@ class MistralModel(MistralPreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 6b00960f38..13e14a755d 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -24,6 +24,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Callable, List, Optional, Tuple, Union import torch @@ -672,7 +673,7 @@ class MixtralModel(MixtralPreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index b32a8d7987..c7fa30376b 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -19,6 +19,7 @@ # limitations under the License. """PyTorch Mixtral model.""" +from functools import partial from typing import List, Optional, Tuple, Union import torch @@ -400,7 +401,7 @@ class MixtralModel(MistralModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 04cf4d5a2c..78438151b8 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Callable, Optional, Tuple, Union import numpy as np @@ -936,7 +937,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, encoder_hidden_states, diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index db071b526e..f1fdd7c58d 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -832,7 +833,7 @@ class MoonshineDecoder(LlamaModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, encoder_hidden_states, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index bd8a88af33..23acd45eb2 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -4,6 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_olmo.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -559,7 +560,7 @@ class OlmoModel(OlmoPreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index dfdaab9a2b..9af94ae0aa 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -4,6 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_olmo2.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -560,7 +561,7 @@ class Olmo2Model(Olmo2PreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index f071ad043e..a5a008a6f1 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -4,6 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_phi.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -553,7 +554,7 @@ class PhiModel(PhiPreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index 1b98d939bf..4dcf74d741 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -243,7 +244,7 @@ class PhiModel(LlamaModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 8cfd65a6f2..bd781216da 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -20,6 +20,7 @@ # limitations under the License. +from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -623,7 +624,7 @@ class Phi3Model(Phi3PreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index c266ec374c..e009b6f693 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -4,6 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_qwen2.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -561,7 +562,7 @@ class Qwen2Model(Qwen2PreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids,