[GenerationConfig] add additional kwargs handling (#21269)
* add additional kwargs handling * fix issue when serializing * correct order of kwargs removal for serialization in from dict * add `dict_torch_dtype_to_str` in case a dtype is needed for generation * add condition when adding the kwargs : not from config * Add comment based on review Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * add test function * default None when poping arg Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -78,6 +78,20 @@ class GenerationConfigTest(unittest.TestCase):
|
||||
# `.update()` returns a dictionary of unused kwargs
|
||||
self.assertEqual(unused_kwargs, {"foo": "bar"})
|
||||
|
||||
def test_initialize_new_kwargs(self):
|
||||
generation_config = GenerationConfig()
|
||||
generation_config.foo = "bar"
|
||||
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
# update_kwargs was used to update the config on valid attributes
|
||||
self.assertEqual(new_config.foo, "bar")
|
||||
|
||||
generation_config = GenerationConfig.from_model_config(new_config)
|
||||
assert not hasattr(generation_config, "foo") # no new kwargs should be initialized if from config
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class ConfigPushToHubTester(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user