From e5ac23081ec4021818a21d7442d396f31de8c30c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 18 Apr 2025 14:55:43 +0100 Subject: [PATCH] =?UTF-8?q?[Gemma3]=20compile=20=E2=9C=A8=20=20(#37447)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/transformers/cache_utils.py | 6 +-- .../models/cohere2/modeling_cohere2.py | 33 +++++-------- .../models/cohere2/modular_cohere2.py | 46 ++++++++----------- .../models/gemma2/modeling_gemma2.py | 31 +++++-------- .../models/gemma2/modular_gemma2.py | 41 +++++++---------- .../models/gemma3/modeling_gemma3.py | 31 +++++-------- .../models/gemma3/modular_gemma3.py | 36 ++++++--------- tests/generation/test_utils.py | 6 --- tests/models/gemma2/test_modeling_gemma2.py | 4 ++ tests/models/gemma3/test_modeling_gemma3.py | 4 ++ 10 files changed, 90 insertions(+), 148 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 29a30f3ab7..89a017bfdb 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1654,9 +1654,7 @@ class HybridCache(Cache): ``` """ - # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert - # ALL changes from the PR that commented the line below when reactivating it. - # is_compileable = True + is_compileable = True def __init__( self, @@ -1858,8 +1856,6 @@ class HybridChunkedCache(Cache): ``` """ - # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert - # ALL changes from the PR that commented the line below when reactivating it. is_compileable = True def __init__( diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index a7189a1a21..1af14d021c 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -42,6 +42,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_cohere2 import Cohere2Config @@ -300,6 +301,7 @@ class Cohere2DecoderLayer(nn.Module): self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0 self.sliding_window = config.sliding_window + @deprecate_kwarg("last_cache_position", version="4.53.0") def forward( self, hidden_states: torch.Tensor, @@ -309,7 +311,6 @@ class Cohere2DecoderLayer(nn.Module): output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - last_cache_position: int = 0, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -330,7 +331,6 @@ class Cohere2DecoderLayer(nn.Module): (see `past_key_values`). cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence - last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing """ if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding @@ -349,11 +349,16 @@ class Cohere2DecoderLayer(nn.Module): ) attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) # In case we are beyond the sliding window, we need to correctly offset the mask slicing - # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo - offset = last_cache_position - effective_seq_len + offset = cache_position[-1] - effective_seq_len + 1 # Should only be used when beyond the sliding window (i.e. offset > 0) offset = max(0, offset) - attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] + # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`, + # but without data-dependent slicing (i.e. torch.compile friendly) + mask_indexes = torch.arange( + min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device + ) + mask_indexes += offset + attention_mask = attention_mask[:, :, :, mask_indexes] residual = hidden_states @@ -539,6 +544,7 @@ class Cohere2Model(Cohere2PreTrainedModel): @can_return_tuple @add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING) + @deprecate_kwarg("last_cache_position", version="4.53.0") def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -550,7 +556,6 @@ class Cohere2Model(Cohere2PreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - last_cache_position: Optional[int] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -590,16 +595,6 @@ class Cohere2Model(Cohere2PreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing - # (retrieving the same value from `cache_position` later on would crash dynamo) - if last_cache_position is None: - last_cache_position = 0 - if attention_mask is not None: - # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position - # It will break dynamo tracing but there are no way around it (and it should never happen in practice) - last_cache_position = ( - attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() - ) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) @@ -627,7 +622,6 @@ class Cohere2Model(Cohere2PreTrainedModel): output_attentions, use_cache, cache_position, - last_cache_position, ) else: layer_outputs = decoder_layer( @@ -638,7 +632,6 @@ class Cohere2Model(Cohere2PreTrainedModel): output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - last_cache_position=last_cache_position, **flash_attn_kwargs, ) @@ -928,10 +921,6 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): # The clone here is for the same reason as for `position_ids`. model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing - # (retrieving the same value from `cache_position` later on would crash dynamo) - model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 - if ( isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2 diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index 3d1bdaeca9..85a8d04a50 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -23,15 +23,12 @@ import torch.utils.checkpoint from ...cache_utils import Cache, HybridCache from ...configuration_utils import PretrainedConfig from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import ( - BaseModelOutputWithPast, -) +from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack -from ...utils import ( - logging, -) +from ...utils import add_start_docstrings_to_model_forward, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ..cohere.modeling_cohere import ( CohereAttention, CohereDecoderLayer, @@ -45,6 +42,9 @@ from ..cohere.modeling_cohere import ( from ..gemma2.modeling_gemma2 import Gemma2Model +COHERE2_INPUTS_DOCSTRING = None # Will be picked up by modular + + logger = logging.get_logger(__name__) @@ -351,6 +351,7 @@ class Cohere2DecoderLayer(CohereDecoderLayer): self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0 self.sliding_window = config.sliding_window + @deprecate_kwarg("last_cache_position", version="4.53.0") def forward( self, hidden_states: torch.Tensor, @@ -360,7 +361,6 @@ class Cohere2DecoderLayer(CohereDecoderLayer): output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - last_cache_position: int = 0, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -381,7 +381,6 @@ class Cohere2DecoderLayer(CohereDecoderLayer): (see `past_key_values`). cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence - last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing """ if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding @@ -400,11 +399,16 @@ class Cohere2DecoderLayer(CohereDecoderLayer): ) attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) # In case we are beyond the sliding window, we need to correctly offset the mask slicing - # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo - offset = last_cache_position - effective_seq_len + offset = cache_position[-1] - effective_seq_len + 1 # Should only be used when beyond the sliding window (i.e. offset > 0) offset = max(0, offset) - attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] + # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`, + # but without data-dependent slicing (i.e. torch.compile friendly) + mask_indexes = torch.arange( + min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device + ) + mask_indexes += offset + attention_mask = attention_mask[:, :, :, mask_indexes] residual = hidden_states @@ -452,6 +456,9 @@ class Cohere2Model(Gemma2Model): self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) self.rotary_emb = Cohere2RotaryEmbedding(config=config) + @can_return_tuple + @add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING) + @deprecate_kwarg("last_cache_position", version="4.53.0") def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -463,7 +470,6 @@ class Cohere2Model(Gemma2Model): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - last_cache_position: Optional[int] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -503,16 +509,6 @@ class Cohere2Model(Gemma2Model): if position_ids is None: position_ids = cache_position.unsqueeze(0) - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing - # (retrieving the same value from `cache_position` later on would crash dynamo) - if last_cache_position is None: - last_cache_position = 0 - if attention_mask is not None: - # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position - # It will break dynamo tracing but there are no way around it (and it should never happen in practice) - last_cache_position = ( - attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() - ) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) @@ -540,7 +536,6 @@ class Cohere2Model(Gemma2Model): output_attentions, use_cache, cache_position, - last_cache_position, ) else: layer_outputs = decoder_layer( @@ -551,7 +546,6 @@ class Cohere2Model(Gemma2Model): output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - last_cache_position=last_cache_position, **flash_attn_kwargs, ) @@ -625,10 +619,6 @@ class Cohere2ForCausalLM(CohereForCausalLM): # The clone here is for the same reason as for `position_ids`. model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing - # (retrieving the same value from `cache_position` later on would crash dynamo) - model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 - if ( isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2 diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index f0d340048f..353b171042 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -47,6 +47,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_gemma2 import Gemma2Config @@ -285,6 +286,7 @@ class Gemma2DecoderLayer(nn.Module): self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.sliding_window = config.sliding_window + @deprecate_kwarg("last_cache_position", version="4.53.0") def forward( self, hidden_states: torch.Tensor, @@ -295,7 +297,6 @@ class Gemma2DecoderLayer(nn.Module): output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - last_cache_position: int = 0, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding @@ -314,11 +315,16 @@ class Gemma2DecoderLayer(nn.Module): ) attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) # In case we are beyond the sliding window, we need to correctly offset the mask slicing - # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo - offset = last_cache_position - effective_seq_len + offset = cache_position[-1] - effective_seq_len + 1 # Should only be used when beyond the sliding window (i.e. offset > 0) offset = max(0, offset) - attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] + # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`, + # but without data-dependent slicing (i.e. torch.compile friendly) + mask_indexes = torch.arange( + min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device + ) + mask_indexes += offset + attention_mask = attention_mask[:, :, :, mask_indexes] residual = hidden_states @@ -542,6 +548,7 @@ class Gemma2Model(Gemma2PreTrainedModel): @can_return_tuple @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) + @deprecate_kwarg("last_cache_position", version="4.53.0") def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -553,7 +560,6 @@ class Gemma2Model(Gemma2PreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - last_cache_position: Optional[int] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -594,16 +600,6 @@ class Gemma2Model(Gemma2PreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing - # (retrieving the same value from `cache_position` later on would crash dynamo) - if last_cache_position is None: - last_cache_position = 0 - if attention_mask is not None: - # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position - # It will break dynamo tracing but there are no way around it (and it should never happen in practice) - last_cache_position = ( - attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() - ) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) @@ -639,7 +635,6 @@ class Gemma2Model(Gemma2PreTrainedModel): output_attentions, use_cache, cache_position, - last_cache_position, ) else: layer_outputs = decoder_layer( @@ -651,7 +646,6 @@ class Gemma2Model(Gemma2PreTrainedModel): output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - last_cache_position=last_cache_position, **flash_attn_kwargs, ) @@ -922,9 +916,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): **kwargs, ) - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing - # (retrieving the same value from `cache_position` later on would crash dynamo) - model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 if logits_to_keep is None: _ = model_inputs.pop("logits_to_keep", None) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index e06a701fc5..b219384f34 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -24,13 +24,11 @@ from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache, StaticCache from ...configuration_utils import PretrainedConfig from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack -from ...utils import is_torch_flex_attn_available, logging +from ...utils import add_start_docstrings_to_model_forward, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from ..gemma.modeling_gemma import ( GemmaAttention, GemmaForCausalLM, @@ -45,6 +43,7 @@ from ..gemma.modeling_gemma import ( _CHECKPOINT_FOR_DOC = "google/gemma2-7b" +GEMMA2_INPUTS_DOCSTRING = None # Will be picked up by modular if is_torch_flex_attn_available(): @@ -334,6 +333,7 @@ class Gemma2DecoderLayer(nn.Module): self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.sliding_window = config.sliding_window + @deprecate_kwarg("last_cache_position", version="4.53.0") def forward( self, hidden_states: torch.Tensor, @@ -344,7 +344,6 @@ class Gemma2DecoderLayer(nn.Module): output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - last_cache_position: int = 0, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding @@ -363,11 +362,16 @@ class Gemma2DecoderLayer(nn.Module): ) attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) # In case we are beyond the sliding window, we need to correctly offset the mask slicing - # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo - offset = last_cache_position - effective_seq_len + offset = cache_position[-1] - effective_seq_len + 1 # Should only be used when beyond the sliding window (i.e. offset > 0) offset = max(0, offset) - attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] + # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`, + # but without data-dependent slicing (i.e. torch.compile friendly) + mask_indexes = torch.arange( + min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device + ) + mask_indexes += offset + attention_mask = attention_mask[:, :, :, mask_indexes] residual = hidden_states @@ -409,6 +413,9 @@ class Gemma2Model(GemmaModel): [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) + @can_return_tuple + @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) + @deprecate_kwarg("last_cache_position", version="4.53.0") def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -420,7 +427,6 @@ class Gemma2Model(GemmaModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - last_cache_position: Optional[int] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -461,16 +467,6 @@ class Gemma2Model(GemmaModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing - # (retrieving the same value from `cache_position` later on would crash dynamo) - if last_cache_position is None: - last_cache_position = 0 - if attention_mask is not None: - # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position - # It will break dynamo tracing but there are no way around it (and it should never happen in practice) - last_cache_position = ( - attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() - ) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) @@ -506,7 +502,6 @@ class Gemma2Model(GemmaModel): output_attentions, use_cache, cache_position, - last_cache_position, ) else: layer_outputs = decoder_layer( @@ -518,7 +513,6 @@ class Gemma2Model(GemmaModel): output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - last_cache_position=last_cache_position, **flash_attn_kwargs, ) @@ -702,9 +696,6 @@ class Gemma2ForCausalLM(GemmaForCausalLM): **kwargs, ) - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing - # (retrieving the same value from `cache_position` later on would crash dynamo) - model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 if logits_to_keep is None: _ = model_inputs.pop("logits_to_keep", None) diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 50ca08a3f1..170e3d952f 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -45,6 +45,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig @@ -377,6 +378,7 @@ class Gemma3DecoderLayer(nn.Module): self.is_sliding = self.self_attn.is_sliding self.sliding_window = config.sliding_window + @deprecate_kwarg("last_cache_position", version="4.53.0") def forward( self, hidden_states: torch.Tensor, @@ -388,7 +390,6 @@ class Gemma3DecoderLayer(nn.Module): output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - last_cache_position: int = 0, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding @@ -407,11 +408,16 @@ class Gemma3DecoderLayer(nn.Module): ) attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) # In case we are beyond the sliding window, we need to correctly offset the mask slicing - # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo - offset = last_cache_position - effective_seq_len + offset = cache_position[-1] - effective_seq_len + 1 # Should only be used when beyond the sliding window (i.e. offset > 0) offset = max(0, offset) - attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] + # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`, + # but without data-dependent slicing (i.e. torch.compile friendly) + mask_indexes = torch.arange( + min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device + ) + mask_indexes += offset + attention_mask = attention_mask[:, :, :, mask_indexes] residual = hidden_states @@ -626,6 +632,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel): @can_return_tuple @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) + @deprecate_kwarg("last_cache_position", version="4.53.0") def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -637,7 +644,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - last_cache_position: Optional[int] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -678,16 +684,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing - # (retrieving the same value from `cache_position` later on would crash dynamo) - if last_cache_position is None: - last_cache_position = 0 - if attention_mask is not None: - # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position - # It will break dynamo tracing but there are no way around it (and it should never happen in practice) - last_cache_position = ( - attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() - ) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, @@ -723,7 +719,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel): output_attentions, use_cache, cache_position, - last_cache_position, ) else: layer_outputs = decoder_layer( @@ -736,7 +731,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel): output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - last_cache_position=last_cache_position, **flash_attn_kwargs, ) @@ -1009,9 +1003,6 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin): **kwargs, ) - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing - # (retrieving the same value from `cache_position` later on would crash dynamo) - model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 if logits_to_keep is None: _ = model_inputs.pop("logits_to_keep", None) diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index f2e716f216..7c95f63b0e 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -26,11 +26,7 @@ import torch.utils.checkpoint from ...cache_utils import Cache, HybridCache, StaticCache from ...configuration_utils import PretrainedConfig from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - ModelOutput, -) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack @@ -41,6 +37,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ..gemma2.configuration_gemma2 import Gemma2Config from ..gemma2.modeling_gemma2 import ( Gemma2Attention, @@ -460,6 +457,7 @@ class Gemma3DecoderLayer(nn.Module): self.is_sliding = self.self_attn.is_sliding self.sliding_window = config.sliding_window + @deprecate_kwarg("last_cache_position", version="4.53.0") def forward( self, hidden_states: torch.Tensor, @@ -471,7 +469,6 @@ class Gemma3DecoderLayer(nn.Module): output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - last_cache_position: int = 0, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding @@ -490,11 +487,16 @@ class Gemma3DecoderLayer(nn.Module): ) attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) # In case we are beyond the sliding window, we need to correctly offset the mask slicing - # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo - offset = last_cache_position - effective_seq_len + offset = cache_position[-1] - effective_seq_len + 1 # Should only be used when beyond the sliding window (i.e. offset > 0) offset = max(0, offset) - attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] + # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`, + # but without data-dependent slicing (i.e. torch.compile friendly) + mask_indexes = torch.arange( + min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device + ) + mask_indexes += offset + attention_mask = attention_mask[:, :, :, mask_indexes] residual = hidden_states @@ -581,6 +583,9 @@ class Gemma3TextModel(Gemma2Model): config.rope_scaling = {"rope_type": "default"} self.rotary_emb_local = Gemma3RotaryEmbedding(config=config) + @can_return_tuple + @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) + @deprecate_kwarg("last_cache_position", version="4.53.0") def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -592,7 +597,6 @@ class Gemma3TextModel(Gemma2Model): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - last_cache_position: Optional[int] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -633,16 +637,6 @@ class Gemma3TextModel(Gemma2Model): if position_ids is None: position_ids = cache_position.unsqueeze(0) - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing - # (retrieving the same value from `cache_position` later on would crash dynamo) - if last_cache_position is None: - last_cache_position = 0 - if attention_mask is not None: - # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position - # It will break dynamo tracing but there are no way around it (and it should never happen in practice) - last_cache_position = ( - attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() - ) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, @@ -678,7 +672,6 @@ class Gemma3TextModel(Gemma2Model): output_attentions, use_cache, cache_position, - last_cache_position, ) else: layer_outputs = decoder_layer( @@ -691,7 +684,6 @@ class Gemma3TextModel(Gemma2Model): output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - last_cache_position=last_cache_position, **flash_attn_kwargs, ) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 0672589769..fa8bd274cc 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2075,9 +2075,6 @@ class GenerationTesterMixin: Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. ⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️ """ - # Monkey-patching the HybridCache at test-time to continue testing compilation support - HybridCache.is_compileable = True - for model_class in self.all_generative_model_classes: if not model_class._supports_static_cache: self.skipTest("This model doesn't support static cache (= no expectations of compilation support)") @@ -2174,9 +2171,6 @@ class GenerationTesterMixin: Tests that all optional outputs are behaving as expected when compilation is triggered. In essence, it's the same as `test_greedy_generate_dict_outputs`, but with automatic compilation triggered. """ - # Monkey-patching the HybridCache at test-time to continue testing compilation support - HybridCache.is_compileable = True - for model_class in self.all_generative_model_classes: if not model_class._supports_static_cache: self.skipTest("This model doesn't support static cache (= no expectations of compilation support)") diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 3a6093e637..d1ba0cbec4 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -154,6 +154,10 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase): def test_multi_gpu_data_parallel_forward(self): pass + @unittest.skip("Gemma2 has HybridCache which auto-compiles. Compile and FA2 don't work together.") + def test_eager_matches_fa2_generate(self): + pass + @slow @require_torch_accelerator diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index 935d8b884a..be83749cf8 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -329,6 +329,10 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte def test_generate_from_inputs_embeds_with_static_cache(self): pass + @unittest.skip("Gemma3 has HybridCache which auto-compiles. Compile and FA2 don't work together.") + def test_eager_matches_fa2_generate(self): + pass + @unittest.skip( reason="Siglip (vision backbone) uses the same initialization scheme as the Flax original implementation" )