Generate: refuse to save bad generation config files (#28477)
This commit is contained in:
@@ -551,16 +551,13 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
try:
|
try:
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
self.validate()
|
self.validate()
|
||||||
for w in caught_warnings:
|
if len(caught_warnings) > 0:
|
||||||
raise ValueError(w.message)
|
raise ValueError(str([w.message for w in caught_warnings]))
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
warnings.warn(
|
raise ValueError(
|
||||||
"The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. "
|
"The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. "
|
||||||
"Fix these issues to save the configuration. This warning will be raised to an exception in v4.34."
|
"Fix these issues to save the configuration.\n\nThrown during validation:\n" + str(exc)
|
||||||
"\n\nThrown during validation:\n" + str(exc),
|
|
||||||
UserWarning,
|
|
||||||
)
|
)
|
||||||
return
|
|
||||||
|
|
||||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||||
|
|
||||||
|
|||||||
@@ -152,14 +152,13 @@ class GenerationConfigTest(unittest.TestCase):
|
|||||||
"""Tests that we refuse to save a generation config that fails validation."""
|
"""Tests that we refuse to save a generation config that fails validation."""
|
||||||
|
|
||||||
# setting the temperature alone is invalid, as we also need to set do_sample to True -> throws a warning that
|
# setting the temperature alone is invalid, as we also need to set do_sample to True -> throws a warning that
|
||||||
# is caught, doesn't save, and raises a warning
|
# is caught, doesn't save, and raises an exception
|
||||||
config = GenerationConfig()
|
config = GenerationConfig()
|
||||||
config.temperature = 0.5
|
config.temperature = 0.5
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
with self.assertRaises(ValueError) as exc:
|
||||||
config.save_pretrained(tmp_dir)
|
config.save_pretrained(tmp_dir)
|
||||||
self.assertEqual(len(captured_warnings), 1)
|
self.assertTrue("Fix these issues to save the configuration." in str(exc.exception))
|
||||||
self.assertTrue("Fix these issues to save the configuration." in str(captured_warnings[0].message))
|
|
||||||
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
|
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
|
||||||
|
|
||||||
# greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is
|
# greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is
|
||||||
@@ -167,13 +166,12 @@ class GenerationConfigTest(unittest.TestCase):
|
|||||||
config = GenerationConfig()
|
config = GenerationConfig()
|
||||||
config.num_return_sequences = 2
|
config.num_return_sequences = 2
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
with self.assertRaises(ValueError) as exc:
|
||||||
config.save_pretrained(tmp_dir)
|
config.save_pretrained(tmp_dir)
|
||||||
self.assertEqual(len(captured_warnings), 1)
|
self.assertTrue("Fix these issues to save the configuration." in str(exc.exception))
|
||||||
self.assertTrue("Fix these issues to save the configuration." in str(captured_warnings[0].message))
|
|
||||||
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
|
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
|
||||||
|
|
||||||
# final check: no warnings thrown if it is correct, and file is saved
|
# final check: no warnings/exceptions thrown if it is correct, and file is saved
|
||||||
config = GenerationConfig()
|
config = GenerationConfig()
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||||
|
|||||||
Reference in New Issue
Block a user