From 94f487626a296deac0022dda6462c0d9f2336106 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 21 Mar 2025 11:01:09 +0000 Subject: [PATCH] [generate] model defaults being inherited only happens for newer models (#36881) --- src/transformers/generation/utils.py | 60 ++++++++++++++------- tests/models/gemma3/test_modeling_gemma3.py | 20 ++++--- 2 files changed, 54 insertions(+), 26 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index aa6c8fb203..9f669e175f 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un import numpy as np import torch import torch.distributed as dist +from packaging import version from torch import nn from torch.nn import functional as F @@ -1552,7 +1553,7 @@ class GenerationMixin: return generation_config def _prepare_generation_config( - self, generation_config: Optional[GenerationConfig], **kwargs: Dict + self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Dict ) -> Tuple[GenerationConfig, Dict]: """ Prepares the base generation config, then applies any generation configuration options from kwargs. This @@ -1591,23 +1592,38 @@ class GenerationMixin: generation_config = copy.deepcopy(generation_config) - # If `generation_config` is provided, let's fallback ALL default values to the model's generation config if not using_model_generation_config: - modified_values = {} - default_generation_config = GenerationConfig() - for key, default_value in default_generation_config.__dict__.items(): - if key.startswith("_"): # metadata - continue - custom_gen_config_value = getattr(generation_config, key) - model_gen_config_value = getattr(self.generation_config, key) - if custom_gen_config_value == default_value and model_gen_config_value != default_value: - modified_values[key] = model_gen_config_value - setattr(generation_config, key, model_gen_config_value) - if len(modified_values) > 0: - logger.warning_once( - f"`generation_config` default values have been modified to match model-specific defaults: " - f"{modified_values}. If this is not desired, please set these values explicitly." - ) + # If `generation_config` is provided: + # - `use_model_defaults`: let's fallback ALL default values to the model's generation config + # - otherwise: legacy behavior, let's just make sure we have the tokens defined + model_base_version = version.parse(version.parse(self.generation_config.transformers_version).base_version) + if use_model_defaults is True or ( + use_model_defaults is None and model_base_version >= version.parse("4.50.0") + ): + modified_values = {} + default_generation_config = GenerationConfig() + for key, default_value in default_generation_config.__dict__.items(): + if key.startswith("_") or key == "transformers_version": # metadata + continue + custom_gen_config_value = getattr(generation_config, key) + model_gen_config_value = getattr(self.generation_config, key) + if custom_gen_config_value == default_value and model_gen_config_value != default_value: + modified_values[key] = model_gen_config_value + setattr(generation_config, key, model_gen_config_value) + if len(modified_values) > 0: + logger.warning_once( + f"`generation_config` default values have been modified to match model-specific defaults: " + f"{modified_values}. If this is not desired, please set these values explicitly." + ) + else: + if generation_config.bos_token_id is None: + generation_config.bos_token_id = self.generation_config.bos_token_id + if generation_config.eos_token_id is None: + generation_config.eos_token_id = self.generation_config.eos_token_id + if generation_config.pad_token_id is None: + generation_config.pad_token_id = self.generation_config.pad_token_id + if generation_config.decoder_start_token_id is None: + generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id # Finally, apply any passed kwargs model_kwargs = generation_config.update(**kwargs) @@ -1967,6 +1983,7 @@ class GenerationMixin: streamer: Optional["BaseStreamer"] = None, negative_prompt_ids: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, + use_model_defaults: Optional[bool] = None, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: r""" @@ -2031,6 +2048,11 @@ class GenerationMixin: size. This is an experimental feature, subject to breaking API changes in future versions. negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Attention_mask for `negative_prompt_ids`. + use_model_defaults (`bool`, *optional*): + When it is `True`, unset parameters in `generation_config` will be set to the model-specific default + generation configuration (`model.generation_config`), as opposed to the global defaults + (`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be + `True`. kwargs (`Dict[str, Any]`, *optional*): Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder @@ -2058,7 +2080,9 @@ class GenerationMixin: tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation - generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) + generation_config, model_kwargs = self._prepare_generation_config( + generation_config, use_model_defaults, **kwargs + ) self._validate_model_kwargs(model_kwargs.copy()) self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer) diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index 06a476c69a..7904b3f8eb 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -575,8 +575,8 @@ class Gemma3IntegrationTest(unittest.TestCase): def test_generation_beyond_sliding_window_with_generation_config(self): """ - Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 -- - ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`. + Similar to `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 + -- ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`. """ model_id = "google/gemma-3-1b-it" attn_implementation = "sdpa" @@ -594,12 +594,16 @@ class Gemma3IntegrationTest(unittest.TestCase): # Make sure prefill is larger than sliding window input_size = inputs.input_ids.shape[-1] - self.assertTrue(input_size > model.config.sliding_window) + self.assertGreater(input_size, model.config.sliding_window) - generation_config = GenerationConfig(max_new_tokens=20) + generation_config = GenerationConfig(max_new_tokens=5, min_new_tokens=5) + out = model.generate(**inputs, generation_config=generation_config) - out = model.generate(**inputs, generation_config=generation_config)[:, input_size:] - output_text = tokenizer.batch_decode(out) + # Generation works beyond sliding window + self.assertGreater(out.shape[1], model.config.sliding_window) + self.assertEqual(out.shape[1], input_size + 5) - EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip - self.assertEqual(output_text, EXPECTED_COMPLETIONS) + # Note: Auto-inheritance only works for models saved starting from 4.50.0 + model.generation_config.transformers_version = "4.49.0" + with self.assertRaises(RuntimeError): # errors out because it is not using hybrid cache + out = model.generate(**inputs, generation_config=generation_config)