[Gemma3] compile (#37447)

This commit is contained in:
Joao Gante
2025-04-18 14:55:43 +01:00
committed by GitHub
parent a1b82563f1
commit e5ac23081e
10 changed files with 90 additions and 148 deletions

View File

@@ -1654,9 +1654,7 @@ class HybridCache(Cache):
```
"""
# TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
# ALL changes from the PR that commented the line below when reactivating it.
# is_compileable = True
is_compileable = True
def __init__(
self,
@@ -1858,8 +1856,6 @@ class HybridChunkedCache(Cache):
```
"""
# TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
# ALL changes from the PR that commented the line below when reactivating it.
is_compileable = True
def __init__(

View File

@@ -42,6 +42,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
from .configuration_cohere2 import Cohere2Config
@@ -300,6 +301,7 @@ class Cohere2DecoderLayer(nn.Module):
self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
self.sliding_window = config.sliding_window
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
hidden_states: torch.Tensor,
@@ -309,7 +311,6 @@ class Cohere2DecoderLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -330,7 +331,6 @@ class Cohere2DecoderLayer(nn.Module):
(see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
"""
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -349,11 +349,16 @@ class Cohere2DecoderLayer(nn.Module):
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
offset = last_cache_position - effective_seq_len
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states
@@ -539,6 +544,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
@can_return_tuple
@add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING)
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -550,7 +556,6 @@ class Cohere2Model(Cohere2PreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -590,16 +595,6 @@ class Cohere2Model(Cohere2PreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
@@ -627,7 +622,6 @@ class Cohere2Model(Cohere2PreTrainedModel):
output_attentions,
use_cache,
cache_position,
last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -638,7 +632,6 @@ class Cohere2Model(Cohere2PreTrainedModel):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs,
)
@@ -928,10 +921,6 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2

View File

@@ -23,15 +23,12 @@ import torch.utils.checkpoint
from ...cache_utils import Cache, HybridCache
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
)
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import (
logging,
)
from ...utils import add_start_docstrings_to_model_forward, can_return_tuple, logging
from ...utils.deprecation import deprecate_kwarg
from ..cohere.modeling_cohere import (
CohereAttention,
CohereDecoderLayer,
@@ -45,6 +42,9 @@ from ..cohere.modeling_cohere import (
from ..gemma2.modeling_gemma2 import Gemma2Model
COHERE2_INPUTS_DOCSTRING = None # Will be picked up by modular
logger = logging.get_logger(__name__)
@@ -351,6 +351,7 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
self.sliding_window = config.sliding_window
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
hidden_states: torch.Tensor,
@@ -360,7 +361,6 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -381,7 +381,6 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
(see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
"""
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -400,11 +399,16 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
offset = last_cache_position - effective_seq_len
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states
@@ -452,6 +456,9 @@ class Cohere2Model(Gemma2Model):
self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
self.rotary_emb = Cohere2RotaryEmbedding(config=config)
@can_return_tuple
@add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING)
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -463,7 +470,6 @@ class Cohere2Model(Gemma2Model):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -503,16 +509,6 @@ class Cohere2Model(Gemma2Model):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
@@ -540,7 +536,6 @@ class Cohere2Model(Gemma2Model):
output_attentions,
use_cache,
cache_position,
last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -551,7 +546,6 @@ class Cohere2Model(Gemma2Model):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs,
)
@@ -625,10 +619,6 @@ class Cohere2ForCausalLM(CohereForCausalLM):
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2

View File

@@ -47,6 +47,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
from .configuration_gemma2 import Gemma2Config
@@ -285,6 +286,7 @@ class Gemma2DecoderLayer(nn.Module):
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
hidden_states: torch.Tensor,
@@ -295,7 +297,6 @@ class Gemma2DecoderLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -314,11 +315,16 @@ class Gemma2DecoderLayer(nn.Module):
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
offset = last_cache_position - effective_seq_len
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states
@@ -542,6 +548,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
@can_return_tuple
@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -553,7 +560,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -594,16 +600,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
@@ -639,7 +635,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
output_attentions,
use_cache,
cache_position,
last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -651,7 +646,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs,
)
@@ -922,9 +916,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
**kwargs,
)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if logits_to_keep is None:
_ = model_inputs.pop("logits_to_keep", None)

View File

@@ -24,13 +24,11 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache, StaticCache
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import is_torch_flex_attn_available, logging
from ...utils import add_start_docstrings_to_model_forward, can_return_tuple, is_torch_flex_attn_available, logging
from ...utils.deprecation import deprecate_kwarg
from ..gemma.modeling_gemma import (
GemmaAttention,
GemmaForCausalLM,
@@ -45,6 +43,7 @@ from ..gemma.modeling_gemma import (
_CHECKPOINT_FOR_DOC = "google/gemma2-7b"
GEMMA2_INPUTS_DOCSTRING = None # Will be picked up by modular
if is_torch_flex_attn_available():
@@ -334,6 +333,7 @@ class Gemma2DecoderLayer(nn.Module):
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
hidden_states: torch.Tensor,
@@ -344,7 +344,6 @@ class Gemma2DecoderLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -363,11 +362,16 @@ class Gemma2DecoderLayer(nn.Module):
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
offset = last_cache_position - effective_seq_len
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states
@@ -409,6 +413,9 @@ class Gemma2Model(GemmaModel):
[Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
@can_return_tuple
@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -420,7 +427,6 @@ class Gemma2Model(GemmaModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -461,16 +467,6 @@ class Gemma2Model(GemmaModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
@@ -506,7 +502,6 @@ class Gemma2Model(GemmaModel):
output_attentions,
use_cache,
cache_position,
last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -518,7 +513,6 @@ class Gemma2Model(GemmaModel):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs,
)
@@ -702,9 +696,6 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
**kwargs,
)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if logits_to_keep is None:
_ = model_inputs.pop("logits_to_keep", None)

View File

@@ -45,6 +45,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig
@@ -377,6 +378,7 @@ class Gemma3DecoderLayer(nn.Module):
self.is_sliding = self.self_attn.is_sliding
self.sliding_window = config.sliding_window
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
hidden_states: torch.Tensor,
@@ -388,7 +390,6 @@ class Gemma3DecoderLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -407,11 +408,16 @@ class Gemma3DecoderLayer(nn.Module):
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
offset = last_cache_position - effective_seq_len
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states
@@ -626,6 +632,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
@can_return_tuple
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -637,7 +644,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -678,16 +684,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
@@ -723,7 +719,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
output_attentions,
use_cache,
cache_position,
last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -736,7 +731,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs,
)
@@ -1009,9 +1003,6 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
**kwargs,
)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if logits_to_keep is None:
_ = model_inputs.pop("logits_to_keep", None)

View File

@@ -26,11 +26,7 @@ import torch.utils.checkpoint
from ...cache_utils import Cache, HybridCache, StaticCache
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
ModelOutput,
)
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
@@ -41,6 +37,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
from ..gemma2.configuration_gemma2 import Gemma2Config
from ..gemma2.modeling_gemma2 import (
Gemma2Attention,
@@ -460,6 +457,7 @@ class Gemma3DecoderLayer(nn.Module):
self.is_sliding = self.self_attn.is_sliding
self.sliding_window = config.sliding_window
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
hidden_states: torch.Tensor,
@@ -471,7 +469,6 @@ class Gemma3DecoderLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -490,11 +487,16 @@ class Gemma3DecoderLayer(nn.Module):
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
offset = last_cache_position - effective_seq_len
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states
@@ -581,6 +583,9 @@ class Gemma3TextModel(Gemma2Model):
config.rope_scaling = {"rope_type": "default"}
self.rotary_emb_local = Gemma3RotaryEmbedding(config=config)
@can_return_tuple
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -592,7 +597,6 @@ class Gemma3TextModel(Gemma2Model):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -633,16 +637,6 @@ class Gemma3TextModel(Gemma2Model):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
@@ -678,7 +672,6 @@ class Gemma3TextModel(Gemma2Model):
output_attentions,
use_cache,
cache_position,
last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -691,7 +684,6 @@ class Gemma3TextModel(Gemma2Model):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs,
)

View File

@@ -2075,9 +2075,6 @@ class GenerationTesterMixin:
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
"""
# Monkey-patching the HybridCache at test-time to continue testing compilation support
HybridCache.is_compileable = True
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")
@@ -2174,9 +2171,6 @@ class GenerationTesterMixin:
Tests that all optional outputs are behaving as expected when compilation is triggered.
In essence, it's the same as `test_greedy_generate_dict_outputs`, but with automatic compilation triggered.
"""
# Monkey-patching the HybridCache at test-time to continue testing compilation support
HybridCache.is_compileable = True
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")

View File

@@ -154,6 +154,10 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip("Gemma2 has HybridCache which auto-compiles. Compile and FA2 don't work together.")
def test_eager_matches_fa2_generate(self):
pass
@slow
@require_torch_accelerator

View File

@@ -329,6 +329,10 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip("Gemma3 has HybridCache which auto-compiles. Compile and FA2 don't work together.")
def test_eager_matches_fa2_generate(self):
pass
@unittest.skip(
reason="Siglip (vision backbone) uses the same initialization scheme as the Flax original implementation"
)