Generate: unset GenerationConfig parameters do not raise warning (#29119)

This commit is contained in:
Joao Gante
2024-02-20 11:34:31 +00:00
committed by GitHub
parent 7d312ad2e9
commit a7755d2409
6 changed files with 43 additions and 25 deletions

View File

@@ -124,26 +124,44 @@ class GenerationConfigTest(unittest.TestCase):
"""
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
"""
# Case 1: A correct configuration will not throw any warning
# A correct configuration will not throw any warning
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig()
self.assertEqual(len(captured_warnings), 0)
# Case 2: Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
# Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
# parameters with `do_sample=False`). May be escalated to an error in the future.
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(temperature=0.5)
GenerationConfig(do_sample=False, temperature=0.5)
self.assertEqual(len(captured_warnings), 1)
# Case 3: Impossible sets of contraints/parameters will raise an exception
with self.assertRaises(ValueError):
GenerationConfig(num_return_sequences=2)
# Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
# that is done by unsetting the parameter (i.e. setting it to None)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings:
# BAD - 0.9 means it is still set, we should warn
generation_config_bad_temperature.update(temperature=0.9)
self.assertEqual(len(captured_warnings), 1)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings:
# CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn
generation_config_bad_temperature.update(temperature=1.0)
self.assertEqual(len(captured_warnings), 0)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings:
# OK - None means it is unset, nothing to warn about
generation_config_bad_temperature.update(temperature=None)
self.assertEqual(len(captured_warnings), 0)
# Case 4: Passing `generate()`-only flags to `validate` will raise an exception
# Impossible sets of contraints/parameters will raise an exception
with self.assertRaises(ValueError):
GenerationConfig(do_sample=False, num_beams=1, num_return_sequences=2)
# Passing `generate()`-only flags to `validate` will raise an exception
with self.assertRaises(ValueError):
GenerationConfig(logits_processor="foo")
# Case 5: Model-specific parameters will NOT raise an exception or a warning
# Model-specific parameters will NOT raise an exception or a warning
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(foo="bar")
self.assertEqual(len(captured_warnings), 0)