Generate: throw warning when return_dict_in_generate is False but should be True (#33146)
This commit is contained in:
@@ -288,7 +288,9 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
Whether or not to return the unprocessed prediction logit scores. See `logits` under returned tensors for
|
Whether or not to return the unprocessed prediction logit scores. See `logits` under returned tensors for
|
||||||
more details.
|
more details.
|
||||||
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
|
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`], as opposed to returning exclusively the generated
|
||||||
|
sequence. This flag must be set to `True` to return the generation cache (when `use_cache` is `True`)
|
||||||
|
or optional outputs (see flags starting with `output_`).
|
||||||
|
|
||||||
> Special tokens that can be used at generation time
|
> Special tokens that can be used at generation time
|
||||||
|
|
||||||
@@ -334,6 +336,8 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
present in `generate`'s signature will be used in the model forward pass.
|
present in `generate`'s signature will be used in the model forward pass.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
extra_output_flags = ("output_attentions", "output_hidden_states", "output_scores", "output_logits")
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
# Parameters that control the length of the output
|
# Parameters that control the length of the output
|
||||||
self.max_length = kwargs.pop("max_length", 20)
|
self.max_length = kwargs.pop("max_length", 20)
|
||||||
@@ -727,7 +731,17 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
|
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
|
||||||
self.watermarking_config.validate()
|
self.watermarking_config.validate()
|
||||||
|
|
||||||
# 7. check common issue: passing `generate` arguments inside the generation config
|
# 7. other incorrect combinations
|
||||||
|
if self.return_dict_in_generate is not True:
|
||||||
|
for extra_output_flag in self.extra_output_flags:
|
||||||
|
if getattr(self, extra_output_flag) is True:
|
||||||
|
warnings.warn(
|
||||||
|
f"`return_dict_in_generate` is NOT set to `True`, but `{extra_output_flag}` is. When "
|
||||||
|
f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 8. check common issue: passing `generate` arguments inside the generation config
|
||||||
generate_arguments = (
|
generate_arguments = (
|
||||||
"logits_processor",
|
"logits_processor",
|
||||||
"stopping_criteria",
|
"stopping_criteria",
|
||||||
@@ -786,7 +800,8 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
|
|
||||||
if use_auth_token is not None:
|
if use_auth_token is not None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. "
|
||||||
|
"Please use `token` instead.",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
if kwargs.get("token", None) is not None:
|
if kwargs.get("token", None) is not None:
|
||||||
@@ -1189,6 +1204,11 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
|
if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
|
||||||
setattr(config, attr, decoder_config[attr])
|
setattr(config, attr, decoder_config[attr])
|
||||||
|
|
||||||
|
# If any `output_...` flag is set to `True`, we ensure `return_dict_in_generate` is set to `True`.
|
||||||
|
if config.return_dict_in_generate is False:
|
||||||
|
if any(getattr(config, extra_output_flag, False) for extra_output_flag in config.extra_output_flags):
|
||||||
|
config.return_dict_in_generate = True
|
||||||
|
|
||||||
config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
|
config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|||||||
@@ -136,6 +136,10 @@ class GenerationConfigTest(unittest.TestCase):
|
|||||||
GenerationConfig(do_sample=False, temperature=0.5)
|
GenerationConfig(do_sample=False, temperature=0.5)
|
||||||
self.assertEqual(len(captured_warnings), 1)
|
self.assertEqual(len(captured_warnings), 1)
|
||||||
|
|
||||||
|
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||||
|
GenerationConfig(return_dict_in_generate=False, output_scores=True)
|
||||||
|
self.assertEqual(len(captured_warnings), 1)
|
||||||
|
|
||||||
# Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
|
# Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
|
||||||
# that is done by unsetting the parameter (i.e. setting it to None)
|
# that is done by unsetting the parameter (i.e. setting it to None)
|
||||||
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
|
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
|
||||||
|
|||||||
Reference in New Issue
Block a user