[generate] model defaults being inherited only happens for newer models (#36881)

This commit is contained in:
Joao Gante
2025-03-21 11:01:09 +00:00
committed by GitHub
parent f19d018bff
commit 94f487626a
2 changed files with 54 additions and 26 deletions

View File

@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
import numpy as np import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from packaging import version
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
@@ -1552,7 +1553,7 @@ class GenerationMixin:
return generation_config return generation_config
def _prepare_generation_config( def _prepare_generation_config(
self, generation_config: Optional[GenerationConfig], **kwargs: Dict self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Dict
) -> Tuple[GenerationConfig, Dict]: ) -> Tuple[GenerationConfig, Dict]:
""" """
Prepares the base generation config, then applies any generation configuration options from kwargs. This Prepares the base generation config, then applies any generation configuration options from kwargs. This
@@ -1591,12 +1592,18 @@ class GenerationMixin:
generation_config = copy.deepcopy(generation_config) generation_config = copy.deepcopy(generation_config)
# If `generation_config` is provided, let's fallback ALL default values to the model's generation config
if not using_model_generation_config: if not using_model_generation_config:
# If `generation_config` is provided:
# - `use_model_defaults`: let's fallback ALL default values to the model's generation config
# - otherwise: legacy behavior, let's just make sure we have the tokens defined
model_base_version = version.parse(version.parse(self.generation_config.transformers_version).base_version)
if use_model_defaults is True or (
use_model_defaults is None and model_base_version >= version.parse("4.50.0")
):
modified_values = {} modified_values = {}
default_generation_config = GenerationConfig() default_generation_config = GenerationConfig()
for key, default_value in default_generation_config.__dict__.items(): for key, default_value in default_generation_config.__dict__.items():
if key.startswith("_"): # metadata if key.startswith("_") or key == "transformers_version": # metadata
continue continue
custom_gen_config_value = getattr(generation_config, key) custom_gen_config_value = getattr(generation_config, key)
model_gen_config_value = getattr(self.generation_config, key) model_gen_config_value = getattr(self.generation_config, key)
@@ -1608,6 +1615,15 @@ class GenerationMixin:
f"`generation_config` default values have been modified to match model-specific defaults: " f"`generation_config` default values have been modified to match model-specific defaults: "
f"{modified_values}. If this is not desired, please set these values explicitly." f"{modified_values}. If this is not desired, please set these values explicitly."
) )
else:
if generation_config.bos_token_id is None:
generation_config.bos_token_id = self.generation_config.bos_token_id
if generation_config.eos_token_id is None:
generation_config.eos_token_id = self.generation_config.eos_token_id
if generation_config.pad_token_id is None:
generation_config.pad_token_id = self.generation_config.pad_token_id
if generation_config.decoder_start_token_id is None:
generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
# Finally, apply any passed kwargs # Finally, apply any passed kwargs
model_kwargs = generation_config.update(**kwargs) model_kwargs = generation_config.update(**kwargs)
@@ -1967,6 +1983,7 @@ class GenerationMixin:
streamer: Optional["BaseStreamer"] = None, streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: Optional[torch.Tensor] = None, negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None,
use_model_defaults: Optional[bool] = None,
**kwargs, **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]: ) -> Union[GenerateOutput, torch.LongTensor]:
r""" r"""
@@ -2031,6 +2048,11 @@ class GenerationMixin:
size. This is an experimental feature, subject to breaking API changes in future versions. size. This is an experimental feature, subject to breaking API changes in future versions.
negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Attention_mask for `negative_prompt_ids`. Attention_mask for `negative_prompt_ids`.
use_model_defaults (`bool`, *optional*):
When it is `True`, unset parameters in `generation_config` will be set to the model-specific default
generation configuration (`model.generation_config`), as opposed to the global defaults
(`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
`True`.
kwargs (`Dict[str, Any]`, *optional*): kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@@ -2058,7 +2080,9 @@ class GenerationMixin:
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) generation_config, model_kwargs = self._prepare_generation_config(
generation_config, use_model_defaults, **kwargs
)
self._validate_model_kwargs(model_kwargs.copy()) self._validate_model_kwargs(model_kwargs.copy())
self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer) self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)

View File

@@ -575,8 +575,8 @@ class Gemma3IntegrationTest(unittest.TestCase):
def test_generation_beyond_sliding_window_with_generation_config(self): def test_generation_beyond_sliding_window_with_generation_config(self):
""" """
Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 -- Similar to `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684
ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`. -- ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
""" """
model_id = "google/gemma-3-1b-it" model_id = "google/gemma-3-1b-it"
attn_implementation = "sdpa" attn_implementation = "sdpa"
@@ -594,12 +594,16 @@ class Gemma3IntegrationTest(unittest.TestCase):
# Make sure prefill is larger than sliding window # Make sure prefill is larger than sliding window
input_size = inputs.input_ids.shape[-1] input_size = inputs.input_ids.shape[-1]
self.assertTrue(input_size > model.config.sliding_window) self.assertGreater(input_size, model.config.sliding_window)
generation_config = GenerationConfig(max_new_tokens=20) generation_config = GenerationConfig(max_new_tokens=5, min_new_tokens=5)
out = model.generate(**inputs, generation_config=generation_config)
out = model.generate(**inputs, generation_config=generation_config)[:, input_size:] # Generation works beyond sliding window
output_text = tokenizer.batch_decode(out) self.assertGreater(out.shape[1], model.config.sliding_window)
self.assertEqual(out.shape[1], input_size + 5)
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip # Note: Auto-inheritance only works for models saved starting from 4.50.0
self.assertEqual(output_text, EXPECTED_COMPLETIONS) model.generation_config.transformers_version = "4.49.0"
with self.assertRaises(RuntimeError): # errors out because it is not using hybrid cache
out = model.generate(**inputs, generation_config=generation_config)