Fix Cache.max_cache_len max value for Hybrid models (#39737)

* fix gemma

* fix min

* fix quant init issue

* fix gemma 3n

* skip quant cache test

* fix modular

* new test for Gemma

* include cyril change

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
Manuel de Prada Corral
2025-07-29 17:12:50 +02:00
committed by GitHub
parent 075dbbceaa
commit c4e2069898
5 changed files with 82 additions and 40 deletions

View File

@@ -325,8 +325,9 @@ class SlidingWindowLayer(StaticLayer):
sliding_window (`int`): sliding_window (`int`):
Effective window size: number of tokens that are kept on each update call. Effective window size: number of tokens that are kept on each update call.
""" """
kwargs.pop("max_cache_len", None) max_cache_len = kwargs.pop("max_cache_len", None)
super().__init__(*args, max_cache_len=sliding_window, *args, **kwargs) max_cache_len = min(sliding_window, max_cache_len) if max_cache_len is not None else sliding_window
super().__init__(*args, max_cache_len=max_cache_len, *args, **kwargs)
def update( def update(
self, self,
@@ -1277,9 +1278,7 @@ class Cache:
def max_cache_len(self) -> int: def max_cache_len(self) -> int:
"""Return the maximum cache length of the cache""" """Return the maximum cache length of the cache"""
values = [layer.max_cache_len for layer in self.layers] values = [layer.max_cache_len for layer in self.layers]
if len(set(values)) > 1: return max(values)
raise ValueError(f"Max cache length is not consistent across layers: {values}")
return values[0]
@property @property
def is_compileable(self) -> bool: def is_compileable(self) -> bool:
@@ -1655,7 +1654,7 @@ class QuantoQuantizedCache(QuantizedCache):
""" """
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
Cache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor, **kwargs) DynamicCache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor, **kwargs)
class HQQQuantizedCache(QuantizedCache): class HQQQuantizedCache(QuantizedCache):
@@ -1697,7 +1696,7 @@ class HQQQuantizedCache(QuantizedCache):
def __init__(self, backend="HQQ", **kwargs) -> None: def __init__(self, backend="HQQ", **kwargs) -> None:
assert backend == "HQQ" assert backend == "HQQ"
Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs) DynamicCache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs)
class EncoderDecoderCache(Cache): class EncoderDecoderCache(Cache):
@@ -1951,10 +1950,6 @@ def parse_layer_args_from_model_config(
) )
# Adjust max_cache_len for sliding window layers (they can't be larger than sliding window) # Adjust max_cache_len for sliding window layers (they can't be larger than sliding window)
max_cache_len = max_cache_len or config.max_position_embeddings max_cache_len = max_cache_len or config.max_position_embeddings
if getattr(config, "sliding_window", None) is not None:
sliding_window_len = min(config.sliding_window, max_cache_len)
else:
sliding_window_len = None
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads: # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads:
head_dim = ( head_dim = (
config.head_dim config.head_dim
@@ -1981,7 +1976,7 @@ def parse_layer_args_from_model_config(
"layer_device_map": layer_device_map, "layer_device_map": layer_device_map,
"head_dim": head_dim, "head_dim": head_dim,
"num_heads": num_heads, "num_heads": num_heads,
"sliding_window": sliding_window_len, "sliding_window": getattr(config, "sliding_window", None),
} }
return {k: v for k, v in layer_args.items() if v is not None} return {k: v for k, v in layer_args.items() if v is not None}

View File

@@ -30,7 +30,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, HybridCache from ...cache_utils import Cache, DynamicCache, SlidingWindowLayer
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -1327,22 +1327,20 @@ class Gemma3nTextAttention(nn.Module):
query_states = query_states.transpose(1, 2) query_states = query_states.transpose(1, 2)
if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None: if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
# Device of past layer may be different from current one
indices = cache_position.to(past_key_value.layers[self.kv_shared_layer_index].keys.device)
# In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond) # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond)
if isinstance(past_key_value, HybridCache) and self.is_sliding: layer = past_key_value.layers[self.kv_shared_layer_index]
max_length = past_key_value.sliding_window # Device of past layer may be different from current one
indices = ( indices = cache_position.to(layer.keys.device)
slice(0, max_length) # Sliding window cache layers might have smaller size (for full layers, we never go beyond)
if cache_position.shape[0] > max_length if isinstance(layer, SlidingWindowLayer):
else cache_position.clamp(min=0, max=max_length - 1) if cache_position.shape[0] > layer.get_max_cache_shape():
) indices = slice(0, layer.get_max_cache_shape())
else:
indices = indices.clamp(min=0, max=layer.get_max_cache_shape() - 1)
# Device of past layer may be different from current one # Device of past layer may be different from current one
key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices].to(query_states.device) key_states = layer.keys[:, :, indices].to(query_states.device)
value_states = ( value_states = layer.values[:, :, indices].to(query_states.device)
past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices].to(query_states.device)
)
else: else:
key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_proj(hidden_states).view(hidden_shape)
key_states = self.k_norm(key_states) key_states = self.k_norm(key_states)

View File

@@ -23,7 +23,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, HybridCache from ...cache_utils import Cache, DynamicCache, SlidingWindowLayer
from ...configuration_utils import PretrainedConfig, layer_type_validation from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -1769,22 +1769,20 @@ class Gemma3nTextAttention(Gemma3Attention):
query_states = query_states.transpose(1, 2) query_states = query_states.transpose(1, 2)
if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None: if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
# Device of past layer may be different from current one
indices = cache_position.to(past_key_value.layers[self.kv_shared_layer_index].keys.device)
# In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond) # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond)
if isinstance(past_key_value, HybridCache) and self.is_sliding: layer = past_key_value.layers[self.kv_shared_layer_index]
max_length = past_key_value.sliding_window # Device of past layer may be different from current one
indices = ( indices = cache_position.to(layer.keys.device)
slice(0, max_length) # Sliding window cache layers might have smaller size (for full layers, we never go beyond)
if cache_position.shape[0] > max_length if isinstance(layer, SlidingWindowLayer):
else cache_position.clamp(min=0, max=max_length - 1) if cache_position.shape[0] > layer.get_max_cache_shape():
) indices = slice(0, layer.get_max_cache_shape())
else:
indices = indices.clamp(min=0, max=layer.get_max_cache_shape() - 1)
# Device of past layer may be different from current one # Device of past layer may be different from current one
key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices].to(query_states.device) key_states = layer.keys[:, :, indices].to(query_states.device)
value_states = ( value_states = layer.values[:, :, indices].to(query_states.device)
past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices].to(query_states.device)
)
else: else:
key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_proj(hidden_states).view(hidden_shape)
key_states = self.k_norm(key_states) key_states = self.k_norm(key_states)

View File

@@ -151,6 +151,52 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
def test_sdpa_padding_matches_padding_free_with_position_ids(self): def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass pass
def test_generation_beyond_sliding_window_tiny_model(self):
"""Test generation with a tiny randomly initialised model whose input length is larger than the `sliding_window`.
The model is configured with both `full_attention` and `sliding_attention` layers to make sure the hybrid cache
and mask slicing logic is covered.
"""
config = Gemma3TextConfig.from_pretrained("hf-internal-testing/tiny-random-Gemma3ForCausalLM")
config.attn_implementation = "eager"
config.layer_types = ["full_attention", "sliding_attention"]
config.sliding_window = 8
config.max_position_embeddings = 128
model = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-Gemma3ForCausalLM", config=config
).to(torch_device)
input_len = 10
input_ids = torch.tensor(
[
[42300, 241087, 255445, 81315, 193760, 184471, 67719, 98191, 210651, 124725],
[102294, 205314, 226646, 62020, 60245, 68025, 251839, 114053, 4695, 175511],
],
device=torch_device,
)
attention_mask = torch.ones_like(input_ids).to(torch_device)
with torch.no_grad():
_ = model.generate(
input_ids,
attention_mask=attention_mask,
max_new_tokens=1,
do_sample=False,
use_cache=True,
cache_implementation="hybrid",
)
# 2 generations are needed to trigger https://github.com/huggingface/transformers/issues/39711
# Since it requires model._cache to have been previously initialized
output = model.generate(
input_ids,
attention_mask=attention_mask,
max_new_tokens=5,
do_sample=False,
use_cache=True,
cache_implementation="hybrid",
)
generated_sequences = output[:, input_len:].cpu()
EXPECTED_OUTPUT = torch.tensor([[90109, 90109, 90109, 83191, 83191], [246901, 69832, 69832, 69832, 62288]])
torch.testing.assert_close(generated_sequences, EXPECTED_OUTPUT)
class Gemma3Vision2TextModelTester: class Gemma3Vision2TextModelTester:
def __init__( def __init__(

View File

@@ -431,6 +431,11 @@ class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
def test_dola_decoding_sample(self): def test_dola_decoding_sample(self):
pass pass
@pytest.mark.generate
@unittest.skip("Gemma3n does not support QuantizedCache as it performs cache manipulation in the forward pass")
def test_generate_with_quant_cache(self):
pass
class Gemma3nVision2TextModelTester: class Gemma3nVision2TextModelTester:
text_config = {"activation_sparsity_pattern": None} text_config = {"activation_sparsity_pattern": None}