Generate: throw warning when return_dict_in_generate is False but should be True (#33146)

This commit is contained in:
Joao Gante
2024-08-31 10:47:08 +01:00
committed by GitHub
parent 746104ba6f
commit eb5b968c5d
2 changed files with 27 additions and 3 deletions

View File

@@ -136,6 +136,10 @@ class GenerationConfigTest(unittest.TestCase):
GenerationConfig(do_sample=False, temperature=0.5)
self.assertEqual(len(captured_warnings), 1)
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(return_dict_in_generate=False, output_scores=True)
self.assertEqual(len(captured_warnings), 1)
# Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
# that is done by unsetting the parameter (i.e. setting it to None)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)