Fix Cache.max_cache_len max value for Hybrid models (#39737)

* fix gemma

* fix min

* fix quant init issue

* fix gemma 3n

* skip quant cache test

* fix modular

* new test for Gemma

* include cyril change

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
Manuel de Prada Corral
2025-07-29 17:12:50 +02:00
committed by GitHub
parent 075dbbceaa
commit c4e2069898
5 changed files with 82 additions and 40 deletions

View File

@@ -151,6 +151,52 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
def test_generation_beyond_sliding_window_tiny_model(self):
"""Test generation with a tiny randomly initialised model whose input length is larger than the `sliding_window`.
The model is configured with both `full_attention` and `sliding_attention` layers to make sure the hybrid cache
and mask slicing logic is covered.
"""
config = Gemma3TextConfig.from_pretrained("hf-internal-testing/tiny-random-Gemma3ForCausalLM")
config.attn_implementation = "eager"
config.layer_types = ["full_attention", "sliding_attention"]
config.sliding_window = 8
config.max_position_embeddings = 128
model = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-Gemma3ForCausalLM", config=config
).to(torch_device)
input_len = 10
input_ids = torch.tensor(
[
[42300, 241087, 255445, 81315, 193760, 184471, 67719, 98191, 210651, 124725],
[102294, 205314, 226646, 62020, 60245, 68025, 251839, 114053, 4695, 175511],
],
device=torch_device,
)
attention_mask = torch.ones_like(input_ids).to(torch_device)
with torch.no_grad():
_ = model.generate(
input_ids,
attention_mask=attention_mask,
max_new_tokens=1,
do_sample=False,
use_cache=True,
cache_implementation="hybrid",
)
# 2 generations are needed to trigger https://github.com/huggingface/transformers/issues/39711
# Since it requires model._cache to have been previously initialized
output = model.generate(
input_ids,
attention_mask=attention_mask,
max_new_tokens=5,
do_sample=False,
use_cache=True,
cache_implementation="hybrid",
)
generated_sequences = output[:, input_len:].cpu()
EXPECTED_OUTPUT = torch.tensor([[90109, 90109, 90109, 83191, 83191], [246901, 69832, 69832, 69832, 62288]])
torch.testing.assert_close(generated_sequences, EXPECTED_OUTPUT)
class Gemma3Vision2TextModelTester:
def __init__(