[Generation, Gemma 3] When passing a custom generation_config, overwrite default values with the model's base generation_config (#36684)
This commit is contained in:
@@ -1162,8 +1162,8 @@ class GenerationTesterMixin:
|
||||
# The two outputs must match and their shape must be as expected
|
||||
self._check_similar_generate_outputs(low_output, high_output)
|
||||
|
||||
@pytest.mark.generate
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
# This test ensures that the assisted generation does not introduce output changes over greedy search.
|
||||
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
@@ -261,6 +262,7 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
@@ -269,6 +271,7 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
from pytest import mark
|
||||
@@ -81,6 +82,7 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
@@ -89,6 +91,7 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@@ -299,12 +299,13 @@ class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices")
|
||||
def test_assisted_decoding_matches_greedy_search(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
from pytest import mark
|
||||
@@ -96,6 +97,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
@@ -104,6 +106,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
@@ -23,6 +24,7 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
Gemma3Config,
|
||||
Gemma3TextConfig,
|
||||
GenerationConfig,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
@@ -75,6 +77,7 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
@@ -83,6 +86,7 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
@@ -277,6 +281,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
@@ -285,6 +290,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
@@ -551,3 +557,34 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
||||
|
||||
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
|
||||
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
|
||||
|
||||
def test_generation_beyond_sliding_window_with_generation_config(self):
|
||||
"""
|
||||
Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
|
||||
ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
|
||||
"""
|
||||
model_id = "gg-hf-g/gemma-3-1b-it"
|
||||
attn_implementation = "sdpa"
|
||||
|
||||
input_text = [
|
||||
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
|
||||
"A list of colors: red, blue", # This will be almost all padding tokens
|
||||
]
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
|
||||
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
|
||||
# Make sure prefill is larger than sliding window
|
||||
input_size = inputs.input_ids.shape[-1]
|
||||
self.assertTrue(input_size > model.config.sliding_window)
|
||||
|
||||
generation_config = GenerationConfig(max_new_tokens=20)
|
||||
|
||||
out = model.generate(**inputs, generation_config=generation_config)[:, input_size:]
|
||||
output_text = tokenizer.batch_decode(out)
|
||||
|
||||
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
|
||||
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
@@ -351,6 +352,7 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
@@ -359,6 +361,7 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
|
||||
from transformers.testing_utils import (
|
||||
require_bitsandbytes,
|
||||
@@ -375,6 +377,7 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
|
||||
def test_model_parallel_beam_search(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="Rely on `past_key_values` to crop the assistant pkv. Not supported")
|
||||
def test_assisted_decoding_matches_greedy_search(self):
|
||||
pass
|
||||
@@ -383,6 +386,7 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
|
||||
def test_left_padding_compatibility(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="Relies on `past_key_values` returned by the model. Not supported with recurrent gemma")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@@ -423,6 +423,7 @@ class SmolVLMForConditionalGenerationModelTest(GenerationTesterMixin, ModelTeste
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="Cache position is off by one leaving out image tokens, FIXME raushan")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user