Generate: unset GenerationConfig parameters do not raise warning (#29119)
This commit is contained in:
@@ -124,26 +124,44 @@ class GenerationConfigTest(unittest.TestCase):
|
||||
"""
|
||||
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
|
||||
"""
|
||||
# Case 1: A correct configuration will not throw any warning
|
||||
# A correct configuration will not throw any warning
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
GenerationConfig()
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
|
||||
# Case 2: Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
|
||||
# Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
|
||||
# parameters with `do_sample=False`). May be escalated to an error in the future.
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
GenerationConfig(temperature=0.5)
|
||||
GenerationConfig(do_sample=False, temperature=0.5)
|
||||
self.assertEqual(len(captured_warnings), 1)
|
||||
|
||||
# Case 3: Impossible sets of contraints/parameters will raise an exception
|
||||
with self.assertRaises(ValueError):
|
||||
GenerationConfig(num_return_sequences=2)
|
||||
# Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
|
||||
# that is done by unsetting the parameter (i.e. setting it to None)
|
||||
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
# BAD - 0.9 means it is still set, we should warn
|
||||
generation_config_bad_temperature.update(temperature=0.9)
|
||||
self.assertEqual(len(captured_warnings), 1)
|
||||
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
# CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn
|
||||
generation_config_bad_temperature.update(temperature=1.0)
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
# OK - None means it is unset, nothing to warn about
|
||||
generation_config_bad_temperature.update(temperature=None)
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
|
||||
# Case 4: Passing `generate()`-only flags to `validate` will raise an exception
|
||||
# Impossible sets of contraints/parameters will raise an exception
|
||||
with self.assertRaises(ValueError):
|
||||
GenerationConfig(do_sample=False, num_beams=1, num_return_sequences=2)
|
||||
|
||||
# Passing `generate()`-only flags to `validate` will raise an exception
|
||||
with self.assertRaises(ValueError):
|
||||
GenerationConfig(logits_processor="foo")
|
||||
|
||||
# Case 5: Model-specific parameters will NOT raise an exception or a warning
|
||||
# Model-specific parameters will NOT raise an exception or a warning
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
GenerationConfig(foo="bar")
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
|
||||
Reference in New Issue
Block a user