Generation: strict generation config validation at save time (#25411)
* strict gen config save; Add tests
* add note that the warning will be an exception in v4.34
This commit is contained in:
@@ -354,8 +354,8 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
# 1. detect sampling-only parameterization when not in sampling mode
|
# 1. detect sampling-only parameterization when not in sampling mode
|
||||||
if self.do_sample is False:
|
if self.do_sample is False:
|
||||||
greedy_wrong_parameter_msg = (
|
greedy_wrong_parameter_msg = (
|
||||||
"`do_sample` is set to `False`. However, {flag_name} is set to {flag_value} -- this flag is only used "
|
"`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
|
||||||
"in sample-based generation modes. You should set `do_sample=True` or unset {flag_name}."
|
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
|
||||||
+ fix_location
|
+ fix_location
|
||||||
)
|
)
|
||||||
if self.temperature != 1.0:
|
if self.temperature != 1.0:
|
||||||
@@ -392,8 +392,8 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
# 2. detect beam-only parameterization when not in beam mode
|
# 2. detect beam-only parameterization when not in beam mode
|
||||||
if self.num_beams == 1:
|
if self.num_beams == 1:
|
||||||
single_beam_wrong_parameter_msg = (
|
single_beam_wrong_parameter_msg = (
|
||||||
"`num_beams` is set to 1. However, {flag_name} is set to {flag_value} -- this flag is only used in "
|
"`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
|
||||||
"beam-based generation modes. You should set `num_beams>1` or unset {flag_name}." + fix_location
|
"in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`." + fix_location
|
||||||
)
|
)
|
||||||
if self.early_stopping is not False:
|
if self.early_stopping is not False:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
@@ -430,9 +430,9 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
# constrained beam search
|
# constrained beam search
|
||||||
if self.constraints is not None:
|
if self.constraints is not None:
|
||||||
constrained_wrong_parameter_msg = (
|
constrained_wrong_parameter_msg = (
|
||||||
"`constraints` is not `None`, triggering constrained beam search. However, {flag_name} is set to "
|
"`constraints` is not `None`, triggering constrained beam search. However, `{flag_name}` is set "
|
||||||
"{flag_value}, which is incompatible with this generation mode. Set `constraints=None` or unset "
|
"to `{flag_value}`, which is incompatible with this generation mode. Set `constraints=None` or "
|
||||||
"{flag_name} to continue." + fix_location
|
"unset `{flag_name}` to continue." + fix_location
|
||||||
)
|
)
|
||||||
if self.do_sample is True:
|
if self.do_sample is True:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -497,6 +497,22 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
kwargs (`Dict[str, Any]`, *optional*):
|
kwargs (`Dict[str, Any]`, *optional*):
|
||||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# At save time, validate the instance -- if any warning/exception is thrown, we refuse to save the instance
|
||||||
|
try:
|
||||||
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
|
self.validate()
|
||||||
|
for w in caught_warnings:
|
||||||
|
raise ValueError(w.message)
|
||||||
|
except ValueError as exc:
|
||||||
|
warnings.warn(
|
||||||
|
"The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. "
|
||||||
|
"Fix these issues to save the configuration. This warning will be raised to an exception in v4.34."
|
||||||
|
"\n\nThrown during validation:\n" + str(exc),
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||||
|
|
||||||
if use_auth_token is not None:
|
if use_auth_token is not None:
|
||||||
|
|||||||
@@ -14,8 +14,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
import warnings
|
||||||
|
|
||||||
from huggingface_hub import HfFolder, delete_repo
|
from huggingface_hub import HfFolder, delete_repo
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
@@ -118,6 +120,39 @@ class GenerationConfigTest(unittest.TestCase):
|
|||||||
self.assertEqual(loaded_config.do_sample, True)
|
self.assertEqual(loaded_config.do_sample, True)
|
||||||
self.assertEqual(loaded_config.num_beams, 1) # default value
|
self.assertEqual(loaded_config.num_beams, 1) # default value
|
||||||
|
|
||||||
|
def test_refuse_to_save(self):
|
||||||
|
"""Tests that we refuse to save a generation config that fails validation."""
|
||||||
|
|
||||||
|
# setting the temperature alone is invalid, as we also need to set do_sample to True -> throws a warning that
|
||||||
|
# is caught, doesn't save, and raises a warning
|
||||||
|
config = GenerationConfig()
|
||||||
|
config.temperature = 0.5
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||||
|
config.save_pretrained(tmp_dir)
|
||||||
|
self.assertEqual(len(captured_warnings), 1)
|
||||||
|
self.assertTrue("Fix these issues to save the configuration." in str(captured_warnings[0].message))
|
||||||
|
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
|
||||||
|
|
||||||
|
# greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is
|
||||||
|
# caught, doesn't save, and raises a warning
|
||||||
|
config = GenerationConfig()
|
||||||
|
config.num_return_sequences = 2
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||||
|
config.save_pretrained(tmp_dir)
|
||||||
|
self.assertEqual(len(captured_warnings), 1)
|
||||||
|
self.assertTrue("Fix these issues to save the configuration." in str(captured_warnings[0].message))
|
||||||
|
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
|
||||||
|
|
||||||
|
# final check: no warnings thrown if it is correct, and file is saved
|
||||||
|
config = GenerationConfig()
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||||
|
config.save_pretrained(tmp_dir)
|
||||||
|
self.assertEqual(len(captured_warnings), 0)
|
||||||
|
self.assertTrue(len(os.listdir(tmp_dir)) == 1)
|
||||||
|
|
||||||
|
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
class ConfigPushToHubTester(unittest.TestCase):
|
class ConfigPushToHubTester(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user