Generate: GenerationConfig can overwrite attributes at from_pretrained time (#24238)
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -93,6 +93,31 @@ class GenerationConfigTest(unittest.TestCase):
|
||||
generation_config = GenerationConfig.from_model_config(new_config)
|
||||
assert not hasattr(generation_config, "foo") # no new kwargs should be initialized if from config
|
||||
|
||||
def test_kwarg_init(self):
|
||||
"""Tests that we can overwrite attributes at `from_pretrained` time."""
|
||||
default_config = GenerationConfig()
|
||||
self.assertEqual(default_config.temperature, 1.0)
|
||||
self.assertEqual(default_config.do_sample, False)
|
||||
self.assertEqual(default_config.num_beams, 1)
|
||||
|
||||
config = GenerationConfig(
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
length_penalty=1.0,
|
||||
bad_words_ids=[[1, 2, 3], [4, 5]],
|
||||
)
|
||||
self.assertEqual(config.temperature, 0.7)
|
||||
self.assertEqual(config.do_sample, True)
|
||||
self.assertEqual(config.num_beams, 1)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config.save_pretrained(tmp_dir)
|
||||
loaded_config = GenerationConfig.from_pretrained(tmp_dir, temperature=1.0)
|
||||
|
||||
self.assertEqual(loaded_config.temperature, 1.0)
|
||||
self.assertEqual(loaded_config.do_sample, True)
|
||||
self.assertEqual(loaded_config.num_beams, 1) # default value
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class ConfigPushToHubTester(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user