From c4e20698985887215f7e91a02621265f047af2d7 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com> Date: Tue, 29 Jul 2025 17:12:50 +0200 Subject: [PATCH] Fix Cache.max_cache_len max value for Hybrid models (#39737) * fix gemma * fix min * fix quant init issue * fix gemma 3n * skip quant cache test * fix modular * new test for Gemma * include cyril change --------- Co-authored-by: Cyril Vallez --- src/transformers/cache_utils.py | 19 +++----- .../models/gemma3n/modeling_gemma3n.py | 26 +++++------ .../models/gemma3n/modular_gemma3n.py | 26 +++++------ tests/models/gemma3/test_modeling_gemma3.py | 46 +++++++++++++++++++ tests/models/gemma3n/test_modeling_gemma3n.py | 5 ++ 5 files changed, 82 insertions(+), 40 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 8f9a007a10..93c4af7cdc 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -325,8 +325,9 @@ class SlidingWindowLayer(StaticLayer): sliding_window (`int`): Effective window size: number of tokens that are kept on each update call. """ - kwargs.pop("max_cache_len", None) - super().__init__(*args, max_cache_len=sliding_window, *args, **kwargs) + max_cache_len = kwargs.pop("max_cache_len", None) + max_cache_len = min(sliding_window, max_cache_len) if max_cache_len is not None else sliding_window + super().__init__(*args, max_cache_len=max_cache_len, *args, **kwargs) def update( self, @@ -1277,9 +1278,7 @@ class Cache: def max_cache_len(self) -> int: """Return the maximum cache length of the cache""" values = [layer.max_cache_len for layer in self.layers] - if len(set(values)) > 1: - raise ValueError(f"Max cache length is not consistent across layers: {values}") - return values[0] + return max(values) @property def is_compileable(self) -> bool: @@ -1655,7 +1654,7 @@ class QuantoQuantizedCache(QuantizedCache): """ def __init__(self, **kwargs) -> None: - Cache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor, **kwargs) + DynamicCache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor, **kwargs) class HQQQuantizedCache(QuantizedCache): @@ -1697,7 +1696,7 @@ class HQQQuantizedCache(QuantizedCache): def __init__(self, backend="HQQ", **kwargs) -> None: assert backend == "HQQ" - Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs) + DynamicCache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs) class EncoderDecoderCache(Cache): @@ -1951,10 +1950,6 @@ def parse_layer_args_from_model_config( ) # Adjust max_cache_len for sliding window layers (they can't be larger than sliding window) max_cache_len = max_cache_len or config.max_position_embeddings - if getattr(config, "sliding_window", None) is not None: - sliding_window_len = min(config.sliding_window, max_cache_len) - else: - sliding_window_len = None # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads: head_dim = ( config.head_dim @@ -1981,7 +1976,7 @@ def parse_layer_args_from_model_config( "layer_device_map": layer_device_map, "head_dim": head_dim, "num_heads": num_heads, - "sliding_window": sliding_window_len, + "sliding_window": getattr(config, "sliding_window", None), } return {k: v for k, v in layer_args.items() if v is not None} diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 8a00675487..3cf07819be 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -30,7 +30,7 @@ import torch.nn as nn import torch.nn.functional as F from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, HybridCache +from ...cache_utils import Cache, DynamicCache, SlidingWindowLayer from ...generation import GenerationMixin from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -1327,22 +1327,20 @@ class Gemma3nTextAttention(nn.Module): query_states = query_states.transpose(1, 2) if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None: - # Device of past layer may be different from current one - indices = cache_position.to(past_key_value.layers[self.kv_shared_layer_index].keys.device) # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond) - if isinstance(past_key_value, HybridCache) and self.is_sliding: - max_length = past_key_value.sliding_window - indices = ( - slice(0, max_length) - if cache_position.shape[0] > max_length - else cache_position.clamp(min=0, max=max_length - 1) - ) + layer = past_key_value.layers[self.kv_shared_layer_index] + # Device of past layer may be different from current one + indices = cache_position.to(layer.keys.device) + # Sliding window cache layers might have smaller size (for full layers, we never go beyond) + if isinstance(layer, SlidingWindowLayer): + if cache_position.shape[0] > layer.get_max_cache_shape(): + indices = slice(0, layer.get_max_cache_shape()) + else: + indices = indices.clamp(min=0, max=layer.get_max_cache_shape() - 1) # Device of past layer may be different from current one - key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices].to(query_states.device) - value_states = ( - past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices].to(query_states.device) - ) + key_states = layer.keys[:, :, indices].to(query_states.device) + value_states = layer.values[:, :, indices].to(query_states.device) else: key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_norm(key_states) diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index bd20faacf5..2716e9bdfe 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -23,7 +23,7 @@ import torch.nn as nn import torch.nn.functional as F from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, HybridCache +from ...cache_utils import Cache, DynamicCache, SlidingWindowLayer from ...configuration_utils import PretrainedConfig, layer_type_validation from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -1769,22 +1769,20 @@ class Gemma3nTextAttention(Gemma3Attention): query_states = query_states.transpose(1, 2) if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None: - # Device of past layer may be different from current one - indices = cache_position.to(past_key_value.layers[self.kv_shared_layer_index].keys.device) # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond) - if isinstance(past_key_value, HybridCache) and self.is_sliding: - max_length = past_key_value.sliding_window - indices = ( - slice(0, max_length) - if cache_position.shape[0] > max_length - else cache_position.clamp(min=0, max=max_length - 1) - ) + layer = past_key_value.layers[self.kv_shared_layer_index] + # Device of past layer may be different from current one + indices = cache_position.to(layer.keys.device) + # Sliding window cache layers might have smaller size (for full layers, we never go beyond) + if isinstance(layer, SlidingWindowLayer): + if cache_position.shape[0] > layer.get_max_cache_shape(): + indices = slice(0, layer.get_max_cache_shape()) + else: + indices = indices.clamp(min=0, max=layer.get_max_cache_shape() - 1) # Device of past layer may be different from current one - key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices].to(query_states.device) - value_states = ( - past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices].to(query_states.device) - ) + key_states = layer.keys[:, :, indices].to(query_states.device) + value_states = layer.values[:, :, indices].to(query_states.device) else: key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_norm(key_states) diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index 65fcf547b4..43ac57dbb5 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -151,6 +151,52 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase def test_sdpa_padding_matches_padding_free_with_position_ids(self): pass + def test_generation_beyond_sliding_window_tiny_model(self): + """Test generation with a tiny randomly initialised model whose input length is larger than the `sliding_window`. + The model is configured with both `full_attention` and `sliding_attention` layers to make sure the hybrid cache + and mask slicing logic is covered. + """ + config = Gemma3TextConfig.from_pretrained("hf-internal-testing/tiny-random-Gemma3ForCausalLM") + config.attn_implementation = "eager" + config.layer_types = ["full_attention", "sliding_attention"] + config.sliding_window = 8 + config.max_position_embeddings = 128 + model = AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-Gemma3ForCausalLM", config=config + ).to(torch_device) + + input_len = 10 + input_ids = torch.tensor( + [ + [42300, 241087, 255445, 81315, 193760, 184471, 67719, 98191, 210651, 124725], + [102294, 205314, 226646, 62020, 60245, 68025, 251839, 114053, 4695, 175511], + ], + device=torch_device, + ) + attention_mask = torch.ones_like(input_ids).to(torch_device) + with torch.no_grad(): + _ = model.generate( + input_ids, + attention_mask=attention_mask, + max_new_tokens=1, + do_sample=False, + use_cache=True, + cache_implementation="hybrid", + ) + # 2 generations are needed to trigger https://github.com/huggingface/transformers/issues/39711 + # Since it requires model._cache to have been previously initialized + output = model.generate( + input_ids, + attention_mask=attention_mask, + max_new_tokens=5, + do_sample=False, + use_cache=True, + cache_implementation="hybrid", + ) + generated_sequences = output[:, input_len:].cpu() + EXPECTED_OUTPUT = torch.tensor([[90109, 90109, 90109, 83191, 83191], [246901, 69832, 69832, 69832, 62288]]) + torch.testing.assert_close(generated_sequences, EXPECTED_OUTPUT) + class Gemma3Vision2TextModelTester: def __init__( diff --git a/tests/models/gemma3n/test_modeling_gemma3n.py b/tests/models/gemma3n/test_modeling_gemma3n.py index 060bf15ea1..34e474129d 100644 --- a/tests/models/gemma3n/test_modeling_gemma3n.py +++ b/tests/models/gemma3n/test_modeling_gemma3n.py @@ -431,6 +431,11 @@ class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes def test_dola_decoding_sample(self): pass + @pytest.mark.generate + @unittest.skip("Gemma3n does not support QuantizedCache as it performs cache manipulation in the forward pass") + def test_generate_with_quant_cache(self): + pass + class Gemma3nVision2TextModelTester: text_config = {"activation_sparsity_pattern": None}