[Gemma3] compile ✨ (#37447)
This commit is contained in:
@@ -1654,9 +1654,7 @@ class HybridCache(Cache):
|
||||
```
|
||||
"""
|
||||
|
||||
# TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
|
||||
# ALL changes from the PR that commented the line below when reactivating it.
|
||||
# is_compileable = True
|
||||
is_compileable = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1858,8 +1856,6 @@ class HybridChunkedCache(Cache):
|
||||
```
|
||||
"""
|
||||
|
||||
# TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
|
||||
# ALL changes from the PR that commented the line below when reactivating it.
|
||||
is_compileable = True
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -42,6 +42,7 @@ from ...utils import (
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_cohere2 import Cohere2Config
|
||||
|
||||
|
||||
@@ -300,6 +301,7 @@ class Cohere2DecoderLayer(nn.Module):
|
||||
self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
|
||||
self.sliding_window = config.sliding_window
|
||||
|
||||
@deprecate_kwarg("last_cache_position", version="4.53.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -309,7 +311,6 @@ class Cohere2DecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: int = 0,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
@@ -330,7 +331,6 @@ class Cohere2DecoderLayer(nn.Module):
|
||||
(see `past_key_values`).
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
|
||||
"""
|
||||
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
@@ -349,11 +349,16 @@ class Cohere2DecoderLayer(nn.Module):
|
||||
)
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
|
||||
offset = last_cache_position - effective_seq_len
|
||||
offset = cache_position[-1] - effective_seq_len + 1
|
||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||
offset = max(0, offset)
|
||||
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
|
||||
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
mask_indexes = torch.arange(
|
||||
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
|
||||
)
|
||||
mask_indexes += offset
|
||||
attention_mask = attention_mask[:, :, :, mask_indexes]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
@@ -539,6 +544,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING)
|
||||
@deprecate_kwarg("last_cache_position", version="4.53.0")
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
@@ -550,7 +556,6 @@ class Cohere2Model(Cohere2PreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: Optional[int] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@@ -590,16 +595,6 @@ class Cohere2Model(Cohere2PreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
if last_cache_position is None:
|
||||
last_cache_position = 0
|
||||
if attention_mask is not None:
|
||||
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
|
||||
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
|
||||
last_cache_position = (
|
||||
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
|
||||
)
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
@@ -627,7 +622,6 @@ class Cohere2Model(Cohere2PreTrainedModel):
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
last_cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
@@ -638,7 +632,6 @@ class Cohere2Model(Cohere2PreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
last_cache_position=last_cache_position,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
@@ -928,10 +921,6 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
|
||||
# The clone here is for the same reason as for `position_ids`.
|
||||
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
|
||||
|
||||
if (
|
||||
isinstance(past_key_values, HybridCache)
|
||||
and attention_mask.ndim == 2
|
||||
|
||||
@@ -23,15 +23,12 @@ import torch.utils.checkpoint
|
||||
from ...cache_utils import Cache, HybridCache
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutputWithPast
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
logging,
|
||||
)
|
||||
from ...utils import add_start_docstrings_to_model_forward, can_return_tuple, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from ..cohere.modeling_cohere import (
|
||||
CohereAttention,
|
||||
CohereDecoderLayer,
|
||||
@@ -45,6 +42,9 @@ from ..cohere.modeling_cohere import (
|
||||
from ..gemma2.modeling_gemma2 import Gemma2Model
|
||||
|
||||
|
||||
COHERE2_INPUTS_DOCSTRING = None # Will be picked up by modular
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -351,6 +351,7 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
|
||||
self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
|
||||
self.sliding_window = config.sliding_window
|
||||
|
||||
@deprecate_kwarg("last_cache_position", version="4.53.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -360,7 +361,6 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: int = 0,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
@@ -381,7 +381,6 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
|
||||
(see `past_key_values`).
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
|
||||
"""
|
||||
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
@@ -400,11 +399,16 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
|
||||
)
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
|
||||
offset = last_cache_position - effective_seq_len
|
||||
offset = cache_position[-1] - effective_seq_len + 1
|
||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||
offset = max(0, offset)
|
||||
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
|
||||
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
mask_indexes = torch.arange(
|
||||
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
|
||||
)
|
||||
mask_indexes += offset
|
||||
attention_mask = attention_mask[:, :, :, mask_indexes]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
@@ -452,6 +456,9 @@ class Cohere2Model(Gemma2Model):
|
||||
self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
|
||||
self.rotary_emb = Cohere2RotaryEmbedding(config=config)
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING)
|
||||
@deprecate_kwarg("last_cache_position", version="4.53.0")
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
@@ -463,7 +470,6 @@ class Cohere2Model(Gemma2Model):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: Optional[int] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@@ -503,16 +509,6 @@ class Cohere2Model(Gemma2Model):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
if last_cache_position is None:
|
||||
last_cache_position = 0
|
||||
if attention_mask is not None:
|
||||
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
|
||||
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
|
||||
last_cache_position = (
|
||||
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
|
||||
)
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
@@ -540,7 +536,6 @@ class Cohere2Model(Gemma2Model):
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
last_cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
@@ -551,7 +546,6 @@ class Cohere2Model(Gemma2Model):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
last_cache_position=last_cache_position,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
@@ -625,10 +619,6 @@ class Cohere2ForCausalLM(CohereForCausalLM):
|
||||
# The clone here is for the same reason as for `position_ids`.
|
||||
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
|
||||
|
||||
if (
|
||||
isinstance(past_key_values, HybridCache)
|
||||
and attention_mask.ndim == 2
|
||||
|
||||
@@ -47,6 +47,7 @@ from ...utils import (
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_gemma2 import Gemma2Config
|
||||
|
||||
|
||||
@@ -285,6 +286,7 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.sliding_window = config.sliding_window
|
||||
|
||||
@deprecate_kwarg("last_cache_position", version="4.53.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -295,7 +297,6 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: int = 0,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
@@ -314,11 +315,16 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
)
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
|
||||
offset = last_cache_position - effective_seq_len
|
||||
offset = cache_position[-1] - effective_seq_len + 1
|
||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||
offset = max(0, offset)
|
||||
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
|
||||
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
mask_indexes = torch.arange(
|
||||
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
|
||||
)
|
||||
mask_indexes += offset
|
||||
attention_mask = attention_mask[:, :, :, mask_indexes]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
@@ -542,6 +548,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
|
||||
@deprecate_kwarg("last_cache_position", version="4.53.0")
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
@@ -553,7 +560,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: Optional[int] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@@ -594,16 +600,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
if last_cache_position is None:
|
||||
last_cache_position = 0
|
||||
if attention_mask is not None:
|
||||
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
|
||||
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
|
||||
last_cache_position = (
|
||||
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
|
||||
)
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
@@ -639,7 +635,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
last_cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
@@ -651,7 +646,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
last_cache_position=last_cache_position,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
@@ -922,9 +916,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
|
||||
if logits_to_keep is None:
|
||||
_ = model_inputs.pop("logits_to_keep", None)
|
||||
|
||||
|
||||
@@ -24,13 +24,11 @@ from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import is_torch_flex_attn_available, logging
|
||||
from ...utils import add_start_docstrings_to_model_forward, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from ..gemma.modeling_gemma import (
|
||||
GemmaAttention,
|
||||
GemmaForCausalLM,
|
||||
@@ -45,6 +43,7 @@ from ..gemma.modeling_gemma import (
|
||||
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "google/gemma2-7b"
|
||||
GEMMA2_INPUTS_DOCSTRING = None # Will be picked up by modular
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
@@ -334,6 +333,7 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.sliding_window = config.sliding_window
|
||||
|
||||
@deprecate_kwarg("last_cache_position", version="4.53.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -344,7 +344,6 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: int = 0,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
@@ -363,11 +362,16 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
)
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
|
||||
offset = last_cache_position - effective_seq_len
|
||||
offset = cache_position[-1] - effective_seq_len + 1
|
||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||
offset = max(0, offset)
|
||||
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
|
||||
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
mask_indexes = torch.arange(
|
||||
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
|
||||
)
|
||||
mask_indexes += offset
|
||||
attention_mask = attention_mask[:, :, :, mask_indexes]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
@@ -409,6 +413,9 @@ class Gemma2Model(GemmaModel):
|
||||
[Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
|
||||
@deprecate_kwarg("last_cache_position", version="4.53.0")
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
@@ -420,7 +427,6 @@ class Gemma2Model(GemmaModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: Optional[int] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@@ -461,16 +467,6 @@ class Gemma2Model(GemmaModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
if last_cache_position is None:
|
||||
last_cache_position = 0
|
||||
if attention_mask is not None:
|
||||
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
|
||||
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
|
||||
last_cache_position = (
|
||||
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
|
||||
)
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
@@ -506,7 +502,6 @@ class Gemma2Model(GemmaModel):
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
last_cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
@@ -518,7 +513,6 @@ class Gemma2Model(GemmaModel):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
last_cache_position=last_cache_position,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
@@ -702,9 +696,6 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
|
||||
if logits_to_keep is None:
|
||||
_ = model_inputs.pop("logits_to_keep", None)
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ from ...utils import (
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from ..auto import AutoModel, AutoModelForCausalLM
|
||||
from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig
|
||||
|
||||
@@ -377,6 +378,7 @@ class Gemma3DecoderLayer(nn.Module):
|
||||
self.is_sliding = self.self_attn.is_sliding
|
||||
self.sliding_window = config.sliding_window
|
||||
|
||||
@deprecate_kwarg("last_cache_position", version="4.53.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -388,7 +390,6 @@ class Gemma3DecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: int = 0,
|
||||
**kwargs,
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
@@ -407,11 +408,16 @@ class Gemma3DecoderLayer(nn.Module):
|
||||
)
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
|
||||
offset = last_cache_position - effective_seq_len
|
||||
offset = cache_position[-1] - effective_seq_len + 1
|
||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||
offset = max(0, offset)
|
||||
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
|
||||
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
mask_indexes = torch.arange(
|
||||
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
|
||||
)
|
||||
mask_indexes += offset
|
||||
attention_mask = attention_mask[:, :, :, mask_indexes]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
@@ -626,6 +632,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
||||
@deprecate_kwarg("last_cache_position", version="4.53.0")
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
@@ -637,7 +644,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: Optional[int] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@@ -678,16 +684,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
if last_cache_position is None:
|
||||
last_cache_position = 0
|
||||
if attention_mask is not None:
|
||||
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
|
||||
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
|
||||
last_cache_position = (
|
||||
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
|
||||
)
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
@@ -723,7 +719,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
last_cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
@@ -736,7 +731,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
last_cache_position=last_cache_position,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
@@ -1009,9 +1003,6 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
|
||||
if logits_to_keep is None:
|
||||
_ = model_inputs.pop("logits_to_keep", None)
|
||||
|
||||
|
||||
@@ -26,11 +26,7 @@ import torch.utils.checkpoint
|
||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
ModelOutput,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...processing_utils import Unpack
|
||||
@@ -41,6 +37,7 @@ from ...utils import (
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from ..gemma2.configuration_gemma2 import Gemma2Config
|
||||
from ..gemma2.modeling_gemma2 import (
|
||||
Gemma2Attention,
|
||||
@@ -460,6 +457,7 @@ class Gemma3DecoderLayer(nn.Module):
|
||||
self.is_sliding = self.self_attn.is_sliding
|
||||
self.sliding_window = config.sliding_window
|
||||
|
||||
@deprecate_kwarg("last_cache_position", version="4.53.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -471,7 +469,6 @@ class Gemma3DecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: int = 0,
|
||||
**kwargs,
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
@@ -490,11 +487,16 @@ class Gemma3DecoderLayer(nn.Module):
|
||||
)
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
|
||||
offset = last_cache_position - effective_seq_len
|
||||
offset = cache_position[-1] - effective_seq_len + 1
|
||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||
offset = max(0, offset)
|
||||
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
|
||||
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
mask_indexes = torch.arange(
|
||||
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
|
||||
)
|
||||
mask_indexes += offset
|
||||
attention_mask = attention_mask[:, :, :, mask_indexes]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
@@ -581,6 +583,9 @@ class Gemma3TextModel(Gemma2Model):
|
||||
config.rope_scaling = {"rope_type": "default"}
|
||||
self.rotary_emb_local = Gemma3RotaryEmbedding(config=config)
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
||||
@deprecate_kwarg("last_cache_position", version="4.53.0")
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
@@ -592,7 +597,6 @@ class Gemma3TextModel(Gemma2Model):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: Optional[int] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@@ -633,16 +637,6 @@ class Gemma3TextModel(Gemma2Model):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
if last_cache_position is None:
|
||||
last_cache_position = 0
|
||||
if attention_mask is not None:
|
||||
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
|
||||
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
|
||||
last_cache_position = (
|
||||
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
|
||||
)
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
@@ -678,7 +672,6 @@ class Gemma3TextModel(Gemma2Model):
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
last_cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
@@ -691,7 +684,6 @@ class Gemma3TextModel(Gemma2Model):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
last_cache_position=last_cache_position,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -2075,9 +2075,6 @@ class GenerationTesterMixin:
|
||||
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
|
||||
⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
|
||||
"""
|
||||
# Monkey-patching the HybridCache at test-time to continue testing compilation support
|
||||
HybridCache.is_compileable = True
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_static_cache:
|
||||
self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")
|
||||
@@ -2174,9 +2171,6 @@ class GenerationTesterMixin:
|
||||
Tests that all optional outputs are behaving as expected when compilation is triggered.
|
||||
In essence, it's the same as `test_greedy_generate_dict_outputs`, but with automatic compilation triggered.
|
||||
"""
|
||||
# Monkey-patching the HybridCache at test-time to continue testing compilation support
|
||||
HybridCache.is_compileable = True
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_static_cache:
|
||||
self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")
|
||||
|
||||
@@ -154,6 +154,10 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache which auto-compiles. Compile and FA2 don't work together.")
|
||||
def test_eager_matches_fa2_generate(self):
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
|
||||
@@ -329,6 +329,10 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache which auto-compiles. Compile and FA2 don't work together.")
|
||||
def test_eager_matches_fa2_generate(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Siglip (vision backbone) uses the same initialization scheme as the Flax original implementation"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user