From fc8764c9a618add64c33e83720f974750bcd0978 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Sat, 15 Mar 2025 12:40:09 +0000 Subject: [PATCH] [Generation, Gemma 3] When passing a custom `generation_config`, overwrite default values with the model's base `generation_config` (#36684) --- .../generation/configuration_utils.py | 8 +-- src/transformers/generation/utils.py | 65 ++++++++++++++----- tests/generation/test_utils.py | 2 +- .../aya_vision/test_modeling_aya_vision.py | 3 + tests/models/cohere2/test_modeling_cohere2.py | 3 + tests/models/fuyu/test_modeling_fuyu.py | 3 +- tests/models/gemma2/test_modeling_gemma2.py | 3 + tests/models/gemma3/test_modeling_gemma3.py | 37 +++++++++++ .../paligemma2/test_modeling_paligemma2.py | 3 + .../test_modeling_recurrent_gemma.py | 4 ++ tests/models/smolvlm/test_modeling_smolvlm.py | 1 + 11 files changed, 108 insertions(+), 24 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 6ee48ab3f1..9eba1bfc92 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -73,7 +73,7 @@ if is_torch_available(): } QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache} ALL_CACHE_IMPLEMENTATIONS = ( - list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(CACHE_CONFIG_MAPPING.keys()) + ["offloaded"] + list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(CACHE_CONFIG_MAPPING.keys()) + ["offloaded", "dynamic"] ) @@ -175,6 +175,7 @@ class GenerationConfig(PushToHubMixin): cache_implementation (`str`, *optional*, default to `None`): Name of the cache class that will be instantiated in `generate`, for faster decoding. 
Possible values are: + - `"dynamic"`: [`DynamicCache`] - `"static"`: [`StaticCache`] - `"offloaded_static"`: [`OffloadedStaticCache`] - `"sliding_window"`: [`SlidingWindowCache`] @@ -182,9 +183,8 @@ class GenerationConfig(PushToHubMixin): - `"mamba"`: [`MambaCache`] - `"quantized"`: [`QuantizedCache`] - We support other cache types, but they must be manually instantiated and - passed to `generate` through the `past_key_values` argument. See our - [cache documentation](https://huggingface.co/docs/transformers/en/kv_cache) for further information. + If none is specified, we will use the default cache for the model (which is often [`DynamicCache`]). See + our [cache documentation](https://huggingface.co/docs/transformers/en/kv_cache) for further information. cache_config (`CacheConfig` or `dict`, *optional*, default to `None`): Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and it will be converted to its repsective `CacheConfig` internally. diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8d5be7d7a0..2e73e423ab 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1177,21 +1177,37 @@ class GenerationMixin: default_list: Union[LogitsProcessorList, StoppingCriteriaList], custom_list: Union[LogitsProcessorList, StoppingCriteriaList], ) -> Union[LogitsProcessorList, StoppingCriteriaList]: + """ + Merge user-defined processors/criteria with the ones instantiated inside `generate`. In case the same + processor/criteria is present on both lists, use the user-defined one. + + (Note: up to v4.49.0, this function threw an exception if the same logit processor was found twice.) 
+ """ if len(custom_list) == 0: return default_list + + final_list = type(default_list)() for default in default_list: + using_custom = False for custom in custom_list: if type(custom) is type(default): object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor" - raise ValueError( - f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to" - f" `.generate()`, but it has already been created with the values {default}. {default} has been" - " created by passing the corresponding arguments to generate or by the model's config default" - f" values. If you just want to change the default values of {object_type} consider passing" - f" them as arguments to `.generate()` instead of using a custom {object_type}." + logger.warning_once( + f"A custom {object_type} of type {type(custom)} has been passed to `.generate()`, but it " + f"was also created in `.generate()`, given its parameterization. The custom {type(custom)} " + f"will take precedence. Please check the docstring of {type(custom)} to see related " + "`.generate()` flags." ) - default_list.extend(custom_list) - return default_list + final_list.append(custom) + using_custom = True + break + if not using_custom: + final_list.append(default) + + for custom in custom_list: + if custom not in final_list: + final_list.append(custom) + return final_list def compute_transition_scores( self, @@ -1573,17 +1589,28 @@ class GenerationMixin: # exception will be raised in `_validate_model_kwargs` if not is_torchdynamo_compiling(): generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) - # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model + + # If `generation_config` is provided, let's fallback ALL default values to the model's generation config + # TODO (joao): per-model generation config classes. 
if not using_model_generation_config: - if generation_config.bos_token_id is None: - generation_config.bos_token_id = self.generation_config.bos_token_id - if generation_config.eos_token_id is None: - generation_config.eos_token_id = self.generation_config.eos_token_id - if generation_config.pad_token_id is None: - generation_config.pad_token_id = self.generation_config.pad_token_id - if generation_config.decoder_start_token_id is None: - generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id + modified_values = {} + default_generation_config = GenerationConfig() + for key, default_value in default_generation_config.__dict__.items(): + if key.startswith("_"): # metadata + continue + custom_gen_config_value = getattr(generation_config, key) + model_gen_config_value = getattr(self.generation_config, key) + if custom_gen_config_value == default_value and model_gen_config_value != default_value: + modified_values[key] = model_gen_config_value + setattr(generation_config, key, model_gen_config_value) + if len(modified_values) > 0: + logger.warning_once( + f"`generation_config` default values have been modified to match model-specific defaults: " + f"{modified_values}. If this is not desired, please set these values explicitly." + ) + + # Finally, apply any passed kwargs + model_kwargs = generation_config.update(**kwargs) else: model_kwargs = kwargs @@ -1837,6 +1864,8 @@ class GenerationMixin: model_kwargs[cache_name] = cache_class(cache_config) elif generation_config.cache_implementation == "offloaded": model_kwargs[cache_name] = OffloadedCache() + elif generation_config.cache_implementation == "dynamic": + model_kwargs[cache_name] = DynamicCache() # Use DynamicCache() instance by default. 
This will avoid back and forth from legacy format that # keeps copying the cache thus using much more memory diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index df58c7fc5c..e6cbe5267a 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1162,8 +1162,8 @@ class GenerationTesterMixin: # The two outputs must match and their shape must be as expected self._check_similar_generate_outputs(low_output, high_output) - @pytest.mark.generate @parameterized.expand([("random",), ("same",)]) + @pytest.mark.generate def test_assisted_decoding_matches_greedy_search(self, assistant_type): # This test ensures that the assisted generation does not introduce output changes over greedy search. # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info. diff --git a/tests/models/aya_vision/test_modeling_aya_vision.py b/tests/models/aya_vision/test_modeling_aya_vision.py index d7d2bd9183..6d42d4e98e 100644 --- a/tests/models/aya_vision/test_modeling_aya_vision.py +++ b/tests/models/aya_vision/test_modeling_aya_vision.py @@ -16,6 +16,7 @@ import unittest +import pytest from parameterized import parameterized from transformers import ( @@ -261,6 +262,7 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester pass @parameterized.expand([("random",), ("same",)]) + @pytest.mark.generate @unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding") def test_assisted_decoding_matches_greedy_search(self, assistant_type): pass @@ -269,6 +271,7 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type): pass + @pytest.mark.generate @unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding") def test_assisted_decoding_sample(self): pass diff --git a/tests/models/cohere2/test_modeling_cohere2.py 
b/tests/models/cohere2/test_modeling_cohere2.py index 418cf50ada..699eb15ddb 100644 --- a/tests/models/cohere2/test_modeling_cohere2.py +++ b/tests/models/cohere2/test_modeling_cohere2.py @@ -16,6 +16,7 @@ import unittest +import pytest from packaging import version from parameterized import parameterized from pytest import mark @@ -81,6 +82,7 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase): pass @parameterized.expand([("random",), ("same",)]) + @pytest.mark.generate @unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding") def test_assisted_decoding_matches_greedy_search(self, assistant_type): pass @@ -89,6 +91,7 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase): def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type): pass + @pytest.mark.generate @unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding") def test_assisted_decoding_sample(self): pass diff --git a/tests/models/fuyu/test_modeling_fuyu.py b/tests/models/fuyu/test_modeling_fuyu.py index a908567d24..835b208c7d 100644 --- a/tests/models/fuyu/test_modeling_fuyu.py +++ b/tests/models/fuyu/test_modeling_fuyu.py @@ -299,12 +299,13 @@ class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @pytest.mark.generate @parameterized.expand([("random",), ("same",)]) + @pytest.mark.generate @unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices") def test_assisted_decoding_matches_greedy_search(self): pass + @pytest.mark.generate @unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices") def test_assisted_decoding_sample(self): pass diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index d6db7079cb..3a51e0bbf7 100644 --- 
a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -16,6 +16,7 @@ import unittest +import pytest from packaging import version from parameterized import parameterized from pytest import mark @@ -96,6 +97,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase): pass @parameterized.expand([("random",), ("same",)]) + @pytest.mark.generate @unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding") def test_assisted_decoding_matches_greedy_search(self, assistant_type): pass @@ -104,6 +106,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase): def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type): pass + @pytest.mark.generate @unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding") def test_assisted_decoding_sample(self): pass diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index f9b7ad1003..3586d35cb3 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -16,6 +16,7 @@ import unittest +import pytest from parameterized import parameterized from transformers import ( @@ -23,6 +24,7 @@ from transformers import ( AutoTokenizer, Gemma3Config, Gemma3TextConfig, + GenerationConfig, is_torch_available, ) from transformers.testing_utils import ( @@ -75,6 +77,7 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase pass @parameterized.expand([("random",), ("same",)]) + @pytest.mark.generate @unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding") def test_assisted_decoding_matches_greedy_search(self, assistant_type): pass @@ -83,6 +86,7 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type): pass + @pytest.mark.generate @unittest.skip("Gemma3 has HybridCache which is 
not compatible with assisted decoding") def test_assisted_decoding_sample(self): pass @@ -277,6 +281,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte pass @parameterized.expand([("random",), ("same",)]) + @pytest.mark.generate @unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding") def test_assisted_decoding_matches_greedy_search(self, assistant_type): pass @@ -285,6 +290,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type): pass + @pytest.mark.generate @unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding") def test_assisted_decoding_sample(self): pass @@ -551,3 +557,34 @@ class Gemma3IntegrationTest(unittest.TestCase): EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip self.assertEqual(output_text, EXPECTED_COMPLETIONS) + + def test_generation_beyond_sliding_window_with_generation_config(self): + """ + Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 -- + ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`. + """ + model_id = "gg-hf-g/gemma-3-1b-it" + attn_implementation = "sdpa" + + input_text = [ + "This is a nice place. 
" * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens + "A list of colors: red, blue", # This will almost all be padding tokens + ] + tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left") + inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device) + + model = AutoModelForCausalLM.from_pretrained( + model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16 + ).to(torch_device) + + # Make sure prefill is larger than sliding window + input_size = inputs.input_ids.shape[-1] + self.assertTrue(input_size > model.config.sliding_window) + + generation_config = GenerationConfig(max_new_tokens=20) + + out = model.generate(**inputs, generation_config=generation_config)[:, input_size:] + output_text = tokenizer.batch_decode(out) + + EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip + self.assertEqual(output_text, EXPECTED_COMPLETIONS) diff --git a/tests/models/paligemma2/test_modeling_paligemma2.py b/tests/models/paligemma2/test_modeling_paligemma2.py index a905d48350..cd159e750d 100644 --- a/tests/models/paligemma2/test_modeling_paligemma2.py +++ b/tests/models/paligemma2/test_modeling_paligemma2.py @@ -16,6 +16,7 @@ import unittest +import pytest from parameterized import parameterized from transformers import ( @@ -351,6 +352,7 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe pass @parameterized.expand([("random",), ("same",)]) + @pytest.mark.generate @unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding") def test_assisted_decoding_matches_greedy_search(self, assistant_type): pass @@ -359,6 +361,7 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type): pass + @pytest.mark.generate 
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding") def test_assisted_decoding_sample(self): pass diff --git a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py index 4d09757802..a7a8a74653 100644 --- a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py +++ b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py @@ -16,6 +16,8 @@ import unittest +import pytest + from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed from transformers.testing_utils import ( require_bitsandbytes, @@ -375,6 +377,7 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te def test_model_parallel_beam_search(self): pass + @pytest.mark.generate @unittest.skip(reason="Rely on `past_key_values` to crop the assistant pkv. Not supported") def test_assisted_decoding_matches_greedy_search(self): pass @@ -383,6 +386,7 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te def test_left_padding_compatibility(self): pass + @pytest.mark.generate @unittest.skip(reason="Relies on `past_key_values` returned by the model. Not supported with recurrent gemma") def test_assisted_decoding_sample(self): pass diff --git a/tests/models/smolvlm/test_modeling_smolvlm.py b/tests/models/smolvlm/test_modeling_smolvlm.py index 11a3569f06..050b198565 100644 --- a/tests/models/smolvlm/test_modeling_smolvlm.py +++ b/tests/models/smolvlm/test_modeling_smolvlm.py @@ -423,6 +423,7 @@ class SmolVLMForConditionalGenerationModelTest(GenerationTesterMixin, ModelTeste pass @parameterized.expand([("random",), ("same",)]) + @pytest.mark.generate @unittest.skip(reason="Cache position is off by one leaving out image tokens, FIXME raushan") def test_assisted_decoding_matches_greedy_search(self, assistant_type): pass