From afc45b13ca40f56268e5f135aab2487377fc536b Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 12 Jan 2024 16:01:17 +0000 Subject: [PATCH] Generate: refuse to save bad generation config files (#28477) --- src/transformers/generation/configuration_utils.py | 11 ++++------- tests/generation/test_configuration_utils.py | 14 ++++++-------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 4818ca8d97..21fe916a7a 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -551,16 +551,13 @@ class GenerationConfig(PushToHubMixin): try: with warnings.catch_warnings(record=True) as caught_warnings: self.validate() - for w in caught_warnings: - raise ValueError(w.message) + if len(caught_warnings) > 0: + raise ValueError(str([w.message for w in caught_warnings])) except ValueError as exc: - warnings.warn( + raise ValueError( "The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. " - "Fix these issues to save the configuration. This warning will be raised to an exception in v4.34." - "\n\nThrown during validation:\n" + str(exc), - UserWarning, + "Fix these issues to save the configuration.\n\nThrown during validation:\n" + str(exc) ) - return use_auth_token = kwargs.pop("use_auth_token", None) diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py index e5eb1bb34c..dc69a673ef 100644 --- a/tests/generation/test_configuration_utils.py +++ b/tests/generation/test_configuration_utils.py @@ -152,14 +152,13 @@ class GenerationConfigTest(unittest.TestCase): """Tests that we refuse to save a generation config that fails validation.""" # setting the temperature alone is invalid, as we also need to set do_sample to True -> throws a warning that - # is caught, doesn't save, and raises a warning + # is caught, doesn't save, and raises an exception config = GenerationConfig() config.temperature = 0.5 with tempfile.TemporaryDirectory() as tmp_dir: - with warnings.catch_warnings(record=True) as captured_warnings: + with self.assertRaises(ValueError) as exc: config.save_pretrained(tmp_dir) - self.assertEqual(len(captured_warnings), 1) - self.assertTrue("Fix these issues to save the configuration." in str(captured_warnings[0].message)) + self.assertTrue("Fix these issues to save the configuration." in str(exc.exception)) self.assertTrue(len(os.listdir(tmp_dir)) == 0) # greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is @@ -167,13 +166,12 @@ class GenerationConfigTest(unittest.TestCase): config = GenerationConfig() config.num_return_sequences = 2 with tempfile.TemporaryDirectory() as tmp_dir: - with warnings.catch_warnings(record=True) as captured_warnings: + with self.assertRaises(ValueError) as exc: config.save_pretrained(tmp_dir) - self.assertEqual(len(captured_warnings), 1) - self.assertTrue("Fix these issues to save the configuration." in str(captured_warnings[0].message)) + self.assertTrue("Fix these issues to save the configuration." in str(exc.exception)) self.assertTrue(len(os.listdir(tmp_dir)) == 0) - # final check: no warnings thrown if it is correct, and file is saved + # final check: no warnings/exceptions thrown if it is correct, and file is saved config = GenerationConfig() with tempfile.TemporaryDirectory() as tmp_dir: with warnings.catch_warnings(record=True) as captured_warnings: