[GenerationConfig] add additional kwargs handling (#21269)

* add additional kwargs handling

* fix issue when serializing

* correct order of kwargs removal for serialization in from dict

* add `dict_torch_dtype_to_str` in case a dtype is needed for generation

* add condition when adding the kwargs : not from config

* Add comment based on review

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* add test function

* default None when poping arg

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Arthur
2023-01-24 19:04:42 +01:00
committed by GitHub
parent 9286039c2a
commit 94a7edd938
2 changed files with 43 additions and 3 deletions

View File

@@ -78,6 +78,20 @@ class GenerationConfigTest(unittest.TestCase):
# `.update()` returns a dictionary of unused kwargs
self.assertEqual(unused_kwargs, {"foo": "bar"})
def test_initialize_new_kwargs(self):
generation_config = GenerationConfig()
generation_config.foo = "bar"
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
# update_kwargs was used to update the config on valid attributes
self.assertEqual(new_config.foo, "bar")
generation_config = GenerationConfig.from_model_config(new_config)
assert not hasattr(generation_config, "foo") # no new kwargs should be initialized if from config
@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):