[Gemma3] compile (#37447)

This commit is contained in:
Joao Gante
2025-04-18 14:55:43 +01:00
committed by GitHub
parent a1b82563f1
commit e5ac23081e
10 changed files with 90 additions and 148 deletions

View File

@@ -1654,9 +1654,7 @@ class HybridCache(Cache):
``` ```
""" """
# TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert is_compileable = True
# ALL changes from the PR that commented the line below when reactivating it.
# is_compileable = True
def __init__( def __init__(
self, self,
@@ -1858,8 +1856,6 @@ class HybridChunkedCache(Cache):
``` ```
""" """
# TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
# ALL changes from the PR that commented the line below when reactivating it.
is_compileable = True is_compileable = True
def __init__( def __init__(

View File

@@ -42,6 +42,7 @@ from ...utils import (
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
from ...utils.deprecation import deprecate_kwarg
from .configuration_cohere2 import Cohere2Config from .configuration_cohere2 import Cohere2Config
@@ -300,6 +301,7 @@ class Cohere2DecoderLayer(nn.Module):
self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0 self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
self.sliding_window = config.sliding_window self.sliding_window = config.sliding_window
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -309,7 +311,6 @@ class Cohere2DecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
@@ -330,7 +331,6 @@ class Cohere2DecoderLayer(nn.Module):
(see `past_key_values`). (see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence Indices depicting the position of the input sequence tokens in the sequence
last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
""" """
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -349,11 +349,16 @@ class Cohere2DecoderLayer(nn.Module):
) )
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing # In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo offset = cache_position[-1] - effective_seq_len + 1
offset = last_cache_position - effective_seq_len
# Should only be used when beyond the sliding window (i.e. offset > 0) # Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset) offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states residual = hidden_states
@@ -539,6 +544,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
@can_return_tuple @can_return_tuple
@add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING)
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
@@ -550,7 +556,6 @@ class Cohere2Model(Cohere2PreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs], **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast: ) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -590,16 +595,6 @@ class Cohere2Model(Cohere2PreTrainedModel):
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask( causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
) )
@@ -627,7 +622,6 @@ class Cohere2Model(Cohere2PreTrainedModel):
output_attentions, output_attentions,
use_cache, use_cache,
cache_position, cache_position,
last_cache_position,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
@@ -638,7 +632,6 @@ class Cohere2Model(Cohere2PreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs, **flash_attn_kwargs,
) )
@@ -928,10 +921,6 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
# The clone here is for the same reason as for `position_ids`. # The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if ( if (
isinstance(past_key_values, HybridCache) isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2 and attention_mask.ndim == 2

View File

@@ -23,15 +23,12 @@ import torch.utils.checkpoint
from ...cache_utils import Cache, HybridCache from ...cache_utils import Cache, HybridCache
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ( from ...modeling_outputs import BaseModelOutputWithPast
BaseModelOutputWithPast,
)
from ...modeling_rope_utils import rope_config_validation from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import ( from ...utils import add_start_docstrings_to_model_forward, can_return_tuple, logging
logging, from ...utils.deprecation import deprecate_kwarg
)
from ..cohere.modeling_cohere import ( from ..cohere.modeling_cohere import (
CohereAttention, CohereAttention,
CohereDecoderLayer, CohereDecoderLayer,
@@ -45,6 +42,9 @@ from ..cohere.modeling_cohere import (
from ..gemma2.modeling_gemma2 import Gemma2Model from ..gemma2.modeling_gemma2 import Gemma2Model
COHERE2_INPUTS_DOCSTRING = None # Will be picked up by modular
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -351,6 +351,7 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0 self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
self.sliding_window = config.sliding_window self.sliding_window = config.sliding_window
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -360,7 +361,6 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
@@ -381,7 +381,6 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
(see `past_key_values`). (see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence Indices depicting the position of the input sequence tokens in the sequence
last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
""" """
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -400,11 +399,16 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
) )
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing # In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo offset = cache_position[-1] - effective_seq_len + 1
offset = last_cache_position - effective_seq_len
# Should only be used when beyond the sliding window (i.e. offset > 0) # Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset) offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states residual = hidden_states
@@ -452,6 +456,9 @@ class Cohere2Model(Gemma2Model):
self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
self.rotary_emb = Cohere2RotaryEmbedding(config=config) self.rotary_emb = Cohere2RotaryEmbedding(config=config)
@can_return_tuple
@add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING)
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
@@ -463,7 +470,6 @@ class Cohere2Model(Gemma2Model):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs], **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast: ) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -503,16 +509,6 @@ class Cohere2Model(Gemma2Model):
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask( causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
) )
@@ -540,7 +536,6 @@ class Cohere2Model(Gemma2Model):
output_attentions, output_attentions,
use_cache, use_cache,
cache_position, cache_position,
last_cache_position,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
@@ -551,7 +546,6 @@ class Cohere2Model(Gemma2Model):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs, **flash_attn_kwargs,
) )
@@ -625,10 +619,6 @@ class Cohere2ForCausalLM(CohereForCausalLM):
# The clone here is for the same reason as for `position_ids`. # The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if ( if (
isinstance(past_key_values, HybridCache) isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2 and attention_mask.ndim == 2

View File

@@ -47,6 +47,7 @@ from ...utils import (
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
from ...utils.deprecation import deprecate_kwarg
from .configuration_gemma2 import Gemma2Config from .configuration_gemma2 import Gemma2Config
@@ -285,6 +286,7 @@ class Gemma2DecoderLayer(nn.Module):
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window self.sliding_window = config.sliding_window
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -295,7 +297,6 @@ class Gemma2DecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs, **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -314,11 +315,16 @@ class Gemma2DecoderLayer(nn.Module):
) )
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing # In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo offset = cache_position[-1] - effective_seq_len + 1
offset = last_cache_position - effective_seq_len
# Should only be used when beyond the sliding window (i.e. offset > 0) # Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset) offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states residual = hidden_states
@@ -542,6 +548,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
@can_return_tuple @can_return_tuple
@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
@@ -553,7 +560,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs], **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast: ) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -594,16 +600,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask( causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
) )
@@ -639,7 +635,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
output_attentions, output_attentions,
use_cache, use_cache,
cache_position, cache_position,
last_cache_position,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
@@ -651,7 +646,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs, **flash_attn_kwargs,
) )
@@ -922,9 +916,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
**kwargs, **kwargs,
) )
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if logits_to_keep is None: if logits_to_keep is None:
_ = model_inputs.pop("logits_to_keep", None) _ = model_inputs.pop("logits_to_keep", None)

View File

@@ -24,13 +24,11 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache, StaticCache from ...cache_utils import Cache, HybridCache, StaticCache
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ( from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import is_torch_flex_attn_available, logging from ...utils import add_start_docstrings_to_model_forward, can_return_tuple, is_torch_flex_attn_available, logging
from ...utils.deprecation import deprecate_kwarg
from ..gemma.modeling_gemma import ( from ..gemma.modeling_gemma import (
GemmaAttention, GemmaAttention,
GemmaForCausalLM, GemmaForCausalLM,
@@ -45,6 +43,7 @@ from ..gemma.modeling_gemma import (
_CHECKPOINT_FOR_DOC = "google/gemma2-7b" _CHECKPOINT_FOR_DOC = "google/gemma2-7b"
GEMMA2_INPUTS_DOCSTRING = None # Will be picked up by modular
if is_torch_flex_attn_available(): if is_torch_flex_attn_available():
@@ -334,6 +333,7 @@ class Gemma2DecoderLayer(nn.Module):
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window self.sliding_window = config.sliding_window
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -344,7 +344,6 @@ class Gemma2DecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs, **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -363,11 +362,16 @@ class Gemma2DecoderLayer(nn.Module):
) )
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing # In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo offset = cache_position[-1] - effective_seq_len + 1
offset = last_cache_position - effective_seq_len
# Should only be used when beyond the sliding window (i.e. offset > 0) # Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset) offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states residual = hidden_states
@@ -409,6 +413,9 @@ class Gemma2Model(GemmaModel):
[Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
) )
@can_return_tuple
@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
@@ -420,7 +427,6 @@ class Gemma2Model(GemmaModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs], **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast: ) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -461,16 +467,6 @@ class Gemma2Model(GemmaModel):
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask( causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
) )
@@ -506,7 +502,6 @@ class Gemma2Model(GemmaModel):
output_attentions, output_attentions,
use_cache, use_cache,
cache_position, cache_position,
last_cache_position,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
@@ -518,7 +513,6 @@ class Gemma2Model(GemmaModel):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs, **flash_attn_kwargs,
) )
@@ -702,9 +696,6 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
**kwargs, **kwargs,
) )
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if logits_to_keep is None: if logits_to_keep is None:
_ = model_inputs.pop("logits_to_keep", None) _ = model_inputs.pop("logits_to_keep", None)

View File

@@ -45,6 +45,7 @@ from ...utils import (
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
from ...utils.deprecation import deprecate_kwarg
from ..auto import AutoModel, AutoModelForCausalLM from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig
@@ -377,6 +378,7 @@ class Gemma3DecoderLayer(nn.Module):
self.is_sliding = self.self_attn.is_sliding self.is_sliding = self.self_attn.is_sliding
self.sliding_window = config.sliding_window self.sliding_window = config.sliding_window
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -388,7 +390,6 @@ class Gemma3DecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs, **kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -407,11 +408,16 @@ class Gemma3DecoderLayer(nn.Module):
) )
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing # In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo offset = cache_position[-1] - effective_seq_len + 1
offset = last_cache_position - effective_seq_len
# Should only be used when beyond the sliding window (i.e. offset > 0) # Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset) offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states residual = hidden_states
@@ -626,6 +632,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
@can_return_tuple @can_return_tuple
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
@@ -637,7 +644,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs], **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast: ) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -678,16 +684,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask( causal_mask = self._update_causal_mask(
attention_mask, attention_mask,
inputs_embeds, inputs_embeds,
@@ -723,7 +719,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
output_attentions, output_attentions,
use_cache, use_cache,
cache_position, cache_position,
last_cache_position,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
@@ -736,7 +731,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs, **flash_attn_kwargs,
) )
@@ -1009,9 +1003,6 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
**kwargs, **kwargs,
) )
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if logits_to_keep is None: if logits_to_keep is None:
_ = model_inputs.pop("logits_to_keep", None) _ = model_inputs.pop("logits_to_keep", None)

View File

@@ -26,11 +26,7 @@ import torch.utils.checkpoint
from ...cache_utils import Cache, HybridCache, StaticCache from ...cache_utils import Cache, HybridCache, StaticCache
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ( from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
BaseModelOutputWithPast,
CausalLMOutputWithPast,
ModelOutput,
)
from ...modeling_rope_utils import rope_config_validation from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack from ...processing_utils import Unpack
@@ -41,6 +37,7 @@ from ...utils import (
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
from ...utils.deprecation import deprecate_kwarg
from ..gemma2.configuration_gemma2 import Gemma2Config from ..gemma2.configuration_gemma2 import Gemma2Config
from ..gemma2.modeling_gemma2 import ( from ..gemma2.modeling_gemma2 import (
Gemma2Attention, Gemma2Attention,
@@ -460,6 +457,7 @@ class Gemma3DecoderLayer(nn.Module):
self.is_sliding = self.self_attn.is_sliding self.is_sliding = self.self_attn.is_sliding
self.sliding_window = config.sliding_window self.sliding_window = config.sliding_window
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -471,7 +469,6 @@ class Gemma3DecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs, **kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -490,11 +487,16 @@ class Gemma3DecoderLayer(nn.Module):
) )
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing # In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo offset = cache_position[-1] - effective_seq_len + 1
offset = last_cache_position - effective_seq_len
# Should only be used when beyond the sliding window (i.e. offset > 0) # Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset) offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states residual = hidden_states
@@ -581,6 +583,9 @@ class Gemma3TextModel(Gemma2Model):
config.rope_scaling = {"rope_type": "default"} config.rope_scaling = {"rope_type": "default"}
self.rotary_emb_local = Gemma3RotaryEmbedding(config=config) self.rotary_emb_local = Gemma3RotaryEmbedding(config=config)
@can_return_tuple
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
@@ -592,7 +597,6 @@ class Gemma3TextModel(Gemma2Model):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs], **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast: ) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -633,16 +637,6 @@ class Gemma3TextModel(Gemma2Model):
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask( causal_mask = self._update_causal_mask(
attention_mask, attention_mask,
inputs_embeds, inputs_embeds,
@@ -678,7 +672,6 @@ class Gemma3TextModel(Gemma2Model):
output_attentions, output_attentions,
use_cache, use_cache,
cache_position, cache_position,
last_cache_position,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
@@ -691,7 +684,6 @@ class Gemma3TextModel(Gemma2Model):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs, **flash_attn_kwargs,
) )

View File

@@ -2075,9 +2075,6 @@ class GenerationTesterMixin:
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️ ⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
""" """
# Monkey-patching the HybridCache at test-time to continue testing compilation support
HybridCache.is_compileable = True
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache: if not model_class._supports_static_cache:
self.skipTest("This model doesn't support static cache (= no expectations of compilation support)") self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")
@@ -2174,9 +2171,6 @@ class GenerationTesterMixin:
Tests that all optional outputs are behaving as expected when compilation is triggered. Tests that all optional outputs are behaving as expected when compilation is triggered.
In essence, it's the same as `test_greedy_generate_dict_outputs`, but with automatic compilation triggered. In essence, it's the same as `test_greedy_generate_dict_outputs`, but with automatic compilation triggered.
""" """
# Monkey-patching the HybridCache at test-time to continue testing compilation support
HybridCache.is_compileable = True
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache: if not model_class._supports_static_cache:
self.skipTest("This model doesn't support static cache (= no expectations of compilation support)") self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")

View File

@@ -154,6 +154,10 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
def test_multi_gpu_data_parallel_forward(self): def test_multi_gpu_data_parallel_forward(self):
pass pass
@unittest.skip("Gemma2 has HybridCache which auto-compiles. Compile and FA2 don't work together.")
def test_eager_matches_fa2_generate(self):
pass
@slow @slow
@require_torch_accelerator @require_torch_accelerator

View File

@@ -329,6 +329,10 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
def test_generate_from_inputs_embeds_with_static_cache(self): def test_generate_from_inputs_embeds_with_static_cache(self):
pass pass
@unittest.skip("Gemma3 has HybridCache which auto-compiles. Compile and FA2 don't work together.")
def test_eager_matches_fa2_generate(self):
pass
@unittest.skip( @unittest.skip(
reason="Siglip (vision backbone) uses the same initialization scheme as the Flax original implementation" reason="Siglip (vision backbone) uses the same initialization scheme as the Flax original implementation"
) )