Generation: strict generation config validation at save time (#25411)

* strict gen config save; Add tests

* add note that the warning will be an exception in v4.34
This commit is contained in:
Joao Gante
2023-08-10 10:42:34 +01:00
committed by GitHub
parent 16edf4d9fd
commit 123ad5363f
2 changed files with 58 additions and 7 deletions

View File

@@ -14,8 +14,10 @@
# limitations under the License.
import copy
import os
import tempfile
import unittest
import warnings
from huggingface_hub import HfFolder, delete_repo
from parameterized import parameterized
@@ -118,6 +120,39 @@ class GenerationConfigTest(unittest.TestCase):
self.assertEqual(loaded_config.do_sample, True)
self.assertEqual(loaded_config.num_beams, 1) # default value
def test_refuse_to_save(self):
"""Tests that we refuse to save a generation config that fails validation."""
# setting the temperature alone is invalid, as we also need to set do_sample to True -> throws a warning that
# is caught, doesn't save, and raises a warning
config = GenerationConfig()
config.temperature = 0.5
with tempfile.TemporaryDirectory() as tmp_dir:
with warnings.catch_warnings(record=True) as captured_warnings:
config.save_pretrained(tmp_dir)
self.assertEqual(len(captured_warnings), 1)
self.assertTrue("Fix these issues to save the configuration." in str(captured_warnings[0].message))
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
# greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is
# caught, doesn't save, and raises a warning
config = GenerationConfig()
config.num_return_sequences = 2
with tempfile.TemporaryDirectory() as tmp_dir:
with warnings.catch_warnings(record=True) as captured_warnings:
config.save_pretrained(tmp_dir)
self.assertEqual(len(captured_warnings), 1)
self.assertTrue("Fix these issues to save the configuration." in str(captured_warnings[0].message))
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
# final check: no warnings thrown if it is correct, and file is saved
config = GenerationConfig()
with tempfile.TemporaryDirectory() as tmp_dir:
with warnings.catch_warnings(record=True) as captured_warnings:
config.save_pretrained(tmp_dir)
self.assertEqual(len(captured_warnings), 0)
self.assertTrue(len(os.listdir(tmp_dir)) == 1)
@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):