[Generation, Gemma 3] When passing a custom generation_config, overwrite default values with the model's base generation_config (#36684)

This commit is contained in:
Joao Gante
2025-03-15 12:40:09 +00:00
committed by GitHub
parent f263e88dcf
commit fc8764c9a6
11 changed files with 108 additions and 24 deletions

View File

@@ -1162,8 +1162,8 @@ class GenerationTesterMixin:
# The two outputs must match and their shape must be as expected
self._check_similar_generate_outputs(low_output, high_output)
@pytest.mark.generate
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
# This test ensures that the assisted generation does not introduce output changes over greedy search.
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.

View File

@@ -16,6 +16,7 @@
import unittest
import pytest
from parameterized import parameterized
from transformers import (
@@ -261,6 +262,7 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -269,6 +271,7 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass

View File

@@ -16,6 +16,7 @@
import unittest
import pytest
from packaging import version
from parameterized import parameterized
from pytest import mark
@@ -81,6 +82,7 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -89,6 +91,7 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass

View File

@@ -299,12 +299,13 @@ class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@pytest.mark.generate
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices")
def test_assisted_decoding_matches_greedy_search(self):
pass
@pytest.mark.generate
@unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices")
def test_assisted_decoding_sample(self):
pass

View File

@@ -16,6 +16,7 @@
import unittest
import pytest
from packaging import version
from parameterized import parameterized
from pytest import mark
@@ -96,6 +97,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -104,6 +106,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass

View File

@@ -16,6 +16,7 @@
import unittest
import pytest
from parameterized import parameterized
from transformers import (
@@ -23,6 +24,7 @@ from transformers import (
AutoTokenizer,
Gemma3Config,
Gemma3TextConfig,
GenerationConfig,
is_torch_available,
)
from transformers.testing_utils import (
@@ -75,6 +77,7 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -83,6 +86,7 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@@ -277,6 +281,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -285,6 +290,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@@ -551,3 +557,34 @@ class Gemma3IntegrationTest(unittest.TestCase):
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
def test_generation_beyond_sliding_window_with_generation_config(self):
"""
Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
"""
model_id = "gg-hf-g/gemma-3-1b-it"
attn_implementation = "sdpa"
input_text = [
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
"A list of colors: red, blue", # This will almost all be padding tokens
]
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
).to(torch_device)
# Make sure prefill is larger than sliding window
input_size = inputs.input_ids.shape[-1]
self.assertTrue(input_size > model.config.sliding_window)
generation_config = GenerationConfig(max_new_tokens=20)
out = model.generate(**inputs, generation_config=generation_config)[:, input_size:]
output_text = tokenizer.batch_decode(out)
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
self.assertEqual(output_text, EXPECTED_COMPLETIONS)

View File

@@ -16,6 +16,7 @@
import unittest
import pytest
from parameterized import parameterized
from transformers import (
@@ -351,6 +352,7 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -359,6 +361,7 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass

View File

@@ -16,6 +16,8 @@
import unittest
import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
from transformers.testing_utils import (
require_bitsandbytes,
@@ -375,6 +377,7 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
def test_model_parallel_beam_search(self):
pass
@pytest.mark.generate
@unittest.skip(reason="Rely on `past_key_values` to crop the assistant pkv. Not supported")
def test_assisted_decoding_matches_greedy_search(self):
pass
@@ -383,6 +386,7 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
def test_left_padding_compatibility(self):
pass
@pytest.mark.generate
@unittest.skip(reason="Relies on `past_key_values` returned by the model. Not supported with recurrent gemma")
def test_assisted_decoding_sample(self):
pass

View File

@@ -423,6 +423,7 @@ class SmolVLMForConditionalGenerationModelTest(GenerationTesterMixin, ModelTeste
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip(reason="Cache position is off by one leaving out image tokens, FIXME raushan")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass