Generate: GenerationConfig can overwrite attributes at from_pretrained time (#24238)

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Joao Gante
2023-06-13 17:59:21 +01:00
committed by GitHub
parent 7bb6933b9d
commit b1ea6b4bf5
2 changed files with 30 additions and 4 deletions

View File

@@ -93,6 +93,31 @@ class GenerationConfigTest(unittest.TestCase):
generation_config = GenerationConfig.from_model_config(new_config)
assert not hasattr(generation_config, "foo") # no new kwargs should be initialized if from config
def test_kwarg_init(self):
"""Tests that we can overwrite attributes at `from_pretrained` time."""
default_config = GenerationConfig()
self.assertEqual(default_config.temperature, 1.0)
self.assertEqual(default_config.do_sample, False)
self.assertEqual(default_config.num_beams, 1)
config = GenerationConfig(
do_sample=True,
temperature=0.7,
length_penalty=1.0,
bad_words_ids=[[1, 2, 3], [4, 5]],
)
self.assertEqual(config.temperature, 0.7)
self.assertEqual(config.do_sample, True)
self.assertEqual(config.num_beams, 1)
with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(tmp_dir)
loaded_config = GenerationConfig.from_pretrained(tmp_dir, temperature=1.0)
self.assertEqual(loaded_config.temperature, 1.0)
self.assertEqual(loaded_config.do_sample, True)
self.assertEqual(loaded_config.num_beams, 1) # default value
@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):