[Generation, Gemma 3] When passing a custom generation_config, overwrite default values with the model's base generation_config (#36684)

This commit is contained in:
Joao Gante
2025-03-15 12:40:09 +00:00
committed by GitHub
parent f263e88dcf
commit fc8764c9a6
11 changed files with 108 additions and 24 deletions

View File

@@ -16,6 +16,7 @@
import unittest
import pytest
from parameterized import parameterized
from transformers import (
@@ -23,6 +24,7 @@ from transformers import (
AutoTokenizer,
Gemma3Config,
Gemma3TextConfig,
GenerationConfig,
is_torch_available,
)
from transformers.testing_utils import (
@@ -75,6 +77,7 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -83,6 +86,7 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@@ -277,6 +281,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -285,6 +290,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@@ -551,3 +557,34 @@ class Gemma3IntegrationTest(unittest.TestCase):
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
def test_generation_beyond_sliding_window_with_generation_config(self):
"""
Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
"""
model_id = "gg-hf-g/gemma-3-1b-it"
attn_implementation = "sdpa"
input_text = [
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
"A list of colors: red, blue", # This will almost all be padding tokens
]
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
).to(torch_device)
# Make sure prefill is larger than sliding window
input_size = inputs.input_ids.shape[-1]
self.assertTrue(input_size > model.config.sliding_window)
generation_config = GenerationConfig(max_new_tokens=20)
out = model.generate(**inputs, generation_config=generation_config)[:, input_size:]
output_text = tokenizer.batch_decode(out)
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
self.assertEqual(output_text, EXPECTED_COMPLETIONS)