[generate] model defaults being inherited only happens for newer models (#36881)
This commit is contained in:
@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
from packaging import version
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
@@ -1552,7 +1553,7 @@ class GenerationMixin:
|
|||||||
return generation_config
|
return generation_config
|
||||||
|
|
||||||
def _prepare_generation_config(
|
def _prepare_generation_config(
|
||||||
self, generation_config: Optional[GenerationConfig], **kwargs: Dict
|
self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Dict
|
||||||
) -> Tuple[GenerationConfig, Dict]:
|
) -> Tuple[GenerationConfig, Dict]:
|
||||||
"""
|
"""
|
||||||
Prepares the base generation config, then applies any generation configuration options from kwargs. This
|
Prepares the base generation config, then applies any generation configuration options from kwargs. This
|
||||||
@@ -1591,23 +1592,38 @@ class GenerationMixin:
|
|||||||
|
|
||||||
generation_config = copy.deepcopy(generation_config)
|
generation_config = copy.deepcopy(generation_config)
|
||||||
|
|
||||||
# If `generation_config` is provided, let's fallback ALL default values to the model's generation config
|
|
||||||
if not using_model_generation_config:
|
if not using_model_generation_config:
|
||||||
modified_values = {}
|
# If `generation_config` is provided:
|
||||||
default_generation_config = GenerationConfig()
|
# - `use_model_defaults`: let's fallback ALL default values to the model's generation config
|
||||||
for key, default_value in default_generation_config.__dict__.items():
|
# - otherwise: legacy behavior, let's just make sure we have the tokens defined
|
||||||
if key.startswith("_"): # metadata
|
model_base_version = version.parse(version.parse(self.generation_config.transformers_version).base_version)
|
||||||
continue
|
if use_model_defaults is True or (
|
||||||
custom_gen_config_value = getattr(generation_config, key)
|
use_model_defaults is None and model_base_version >= version.parse("4.50.0")
|
||||||
model_gen_config_value = getattr(self.generation_config, key)
|
):
|
||||||
if custom_gen_config_value == default_value and model_gen_config_value != default_value:
|
modified_values = {}
|
||||||
modified_values[key] = model_gen_config_value
|
default_generation_config = GenerationConfig()
|
||||||
setattr(generation_config, key, model_gen_config_value)
|
for key, default_value in default_generation_config.__dict__.items():
|
||||||
if len(modified_values) > 0:
|
if key.startswith("_") or key == "transformers_version": # metadata
|
||||||
logger.warning_once(
|
continue
|
||||||
f"`generation_config` default values have been modified to match model-specific defaults: "
|
custom_gen_config_value = getattr(generation_config, key)
|
||||||
f"{modified_values}. If this is not desired, please set these values explicitly."
|
model_gen_config_value = getattr(self.generation_config, key)
|
||||||
)
|
if custom_gen_config_value == default_value and model_gen_config_value != default_value:
|
||||||
|
modified_values[key] = model_gen_config_value
|
||||||
|
setattr(generation_config, key, model_gen_config_value)
|
||||||
|
if len(modified_values) > 0:
|
||||||
|
logger.warning_once(
|
||||||
|
f"`generation_config` default values have been modified to match model-specific defaults: "
|
||||||
|
f"{modified_values}. If this is not desired, please set these values explicitly."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if generation_config.bos_token_id is None:
|
||||||
|
generation_config.bos_token_id = self.generation_config.bos_token_id
|
||||||
|
if generation_config.eos_token_id is None:
|
||||||
|
generation_config.eos_token_id = self.generation_config.eos_token_id
|
||||||
|
if generation_config.pad_token_id is None:
|
||||||
|
generation_config.pad_token_id = self.generation_config.pad_token_id
|
||||||
|
if generation_config.decoder_start_token_id is None:
|
||||||
|
generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
|
||||||
|
|
||||||
# Finally, apply any passed kwargs
|
# Finally, apply any passed kwargs
|
||||||
model_kwargs = generation_config.update(**kwargs)
|
model_kwargs = generation_config.update(**kwargs)
|
||||||
@@ -1967,6 +1983,7 @@ class GenerationMixin:
|
|||||||
streamer: Optional["BaseStreamer"] = None,
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
negative_prompt_ids: Optional[torch.Tensor] = None,
|
negative_prompt_ids: Optional[torch.Tensor] = None,
|
||||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
use_model_defaults: Optional[bool] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[GenerateOutput, torch.LongTensor]:
|
) -> Union[GenerateOutput, torch.LongTensor]:
|
||||||
r"""
|
r"""
|
||||||
@@ -2031,6 +2048,11 @@ class GenerationMixin:
|
|||||||
size. This is an experimental feature, subject to breaking API changes in future versions.
|
size. This is an experimental feature, subject to breaking API changes in future versions.
|
||||||
negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Attention_mask for `negative_prompt_ids`.
|
Attention_mask for `negative_prompt_ids`.
|
||||||
|
use_model_defaults (`bool`, *optional*):
|
||||||
|
When it is `True`, unset parameters in `generation_config` will be set to the model-specific default
|
||||||
|
generation configuration (`model.generation_config`), as opposed to the global defaults
|
||||||
|
(`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
|
||||||
|
`True`.
|
||||||
kwargs (`Dict[str, Any]`, *optional*):
|
kwargs (`Dict[str, Any]`, *optional*):
|
||||||
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
|
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
|
||||||
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
||||||
@@ -2058,7 +2080,9 @@ class GenerationMixin:
|
|||||||
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
|
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
|
||||||
assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation
|
assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation
|
||||||
|
|
||||||
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
|
generation_config, model_kwargs = self._prepare_generation_config(
|
||||||
|
generation_config, use_model_defaults, **kwargs
|
||||||
|
)
|
||||||
self._validate_model_kwargs(model_kwargs.copy())
|
self._validate_model_kwargs(model_kwargs.copy())
|
||||||
self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
|
self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
|
||||||
|
|
||||||
|
|||||||
@@ -575,8 +575,8 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_generation_beyond_sliding_window_with_generation_config(self):
|
def test_generation_beyond_sliding_window_with_generation_config(self):
|
||||||
"""
|
"""
|
||||||
Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
|
Similar to `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684
|
||||||
ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
|
-- ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
|
||||||
"""
|
"""
|
||||||
model_id = "google/gemma-3-1b-it"
|
model_id = "google/gemma-3-1b-it"
|
||||||
attn_implementation = "sdpa"
|
attn_implementation = "sdpa"
|
||||||
@@ -594,12 +594,16 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
# Make sure prefill is larger than sliding window
|
# Make sure prefill is larger than sliding window
|
||||||
input_size = inputs.input_ids.shape[-1]
|
input_size = inputs.input_ids.shape[-1]
|
||||||
self.assertTrue(input_size > model.config.sliding_window)
|
self.assertGreater(input_size, model.config.sliding_window)
|
||||||
|
|
||||||
generation_config = GenerationConfig(max_new_tokens=20)
|
generation_config = GenerationConfig(max_new_tokens=5, min_new_tokens=5)
|
||||||
|
out = model.generate(**inputs, generation_config=generation_config)
|
||||||
|
|
||||||
out = model.generate(**inputs, generation_config=generation_config)[:, input_size:]
|
# Generation works beyond sliding window
|
||||||
output_text = tokenizer.batch_decode(out)
|
self.assertGreater(out.shape[1], model.config.sliding_window)
|
||||||
|
self.assertEqual(out.shape[1], input_size + 5)
|
||||||
|
|
||||||
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
|
# Note: Auto-inheritance only works for models saved starting from 4.50.0
|
||||||
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
|
model.generation_config.transformers_version = "4.49.0"
|
||||||
|
with self.assertRaises(RuntimeError): # errors out because it is not using hybrid cache
|
||||||
|
out = model.generate(**inputs, generation_config=generation_config)
|
||||||
|
|||||||
Reference in New Issue
Block a user