[GenerationConfig] add additional kwargs handling (#21269)
* add additional kwargs handling * fix issue when serializing * correct order of kwargs removal for serialization in `from_dict` * add `dict_torch_dtype_to_str` in case a dtype is needed for generation * add condition when adding the kwargs: not from config * Add comment based on review Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * add test function * default None when popping arg Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -282,6 +282,16 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
self._commit_hash = kwargs.pop("_commit_hash", None)
|
self._commit_hash = kwargs.pop("_commit_hash", None)
|
||||||
self.transformers_version = kwargs.pop("transformers_version", __version__)
|
self.transformers_version = kwargs.pop("transformers_version", __version__)
|
||||||
|
|
||||||
|
# Additional attributes without default values
|
||||||
|
if not self._from_model_config:
|
||||||
|
# we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a model's default configuration file
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
try:
|
||||||
|
setattr(self, key, value)
|
||||||
|
except AttributeError as err:
|
||||||
|
logger.error(f"Can't set {key} with value {value} for {self}")
|
||||||
|
raise err
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
self_dict = self.__dict__.copy()
|
self_dict = self.__dict__.copy()
|
||||||
other_dict = other.__dict__.copy()
|
other_dict = other.__dict__.copy()
|
||||||
@@ -537,7 +547,9 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
|
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
|
||||||
kwargs["_commit_hash"] = config_dict["_commit_hash"]
|
kwargs["_commit_hash"] = config_dict["_commit_hash"]
|
||||||
|
|
||||||
config = cls(**config_dict)
|
# remove all the arguments that are in the config_dict
|
||||||
|
|
||||||
|
config = cls(**config_dict, **kwargs)
|
||||||
unused_kwargs = config.update(**kwargs)
|
unused_kwargs = config.update(**kwargs)
|
||||||
|
|
||||||
logger.info(f"Generate config {config}")
|
logger.info(f"Generate config {config}")
|
||||||
@@ -546,6 +558,18 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
else:
|
else:
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
    """
    Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
    converts torch.dtype to a string of just the type. For example, `torch.float32` gets converted into the
    *"float32"* string, which can then be stored in the json format.

    Args:
        d (`Dict[str, Any]`):
            Dictionary to convert. It is modified in place; nothing is returned.
    """
    dtype = d.get("torch_dtype")
    if dtype is not None and not isinstance(dtype, str):
        # `str(torch.float32)` is "torch.float32": keep what follows the last dot. Using the last component
        # (instead of index 1) gives the same result for torch dtypes and does not raise `IndexError` when
        # the string form carries no module prefix.
        d["torch_dtype"] = str(dtype).rpartition(".")[2]

    # Only values that are themselves dicts are visited recursively.
    for value in d.values():
        if isinstance(value, dict):
            self.dict_torch_dtype_to_str(value)
|
||||||
|
|
||||||
def to_diff_dict(self) -> Dict[str, Any]:
|
def to_diff_dict(self) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Removes all attributes from config which correspond to the default config attributes for better readability and
|
Removes all attributes from config which correspond to the default config attributes for better readability and
|
||||||
@@ -566,6 +590,7 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
if key not in default_config_dict or key == "transformers_version" or value != default_config_dict[key]:
|
if key not in default_config_dict or key == "transformers_version" or value != default_config_dict[key]:
|
||||||
serializable_config_dict[key] = value
|
serializable_config_dict[key] = value
|
||||||
|
|
||||||
|
self.dict_torch_dtype_to_str(serializable_config_dict)
|
||||||
return serializable_config_dict
|
return serializable_config_dict
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
@@ -582,6 +607,7 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
# Transformers version when serializing this file
|
# Transformers version when serializing this file
|
||||||
output["transformers_version"] = __version__
|
output["transformers_version"] = __version__
|
||||||
|
|
||||||
|
self.dict_torch_dtype_to_str(output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def to_json_string(self, use_diff: bool = True) -> str:
|
def to_json_string(self, use_diff: bool = True) -> str:
|
||||||
@@ -630,7 +656,8 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
[`GenerationConfig`]: The configuration object instantiated from those parameters.
|
[`GenerationConfig`]: The configuration object instantiated from those parameters.
|
||||||
"""
|
"""
|
||||||
config_dict = model_config.to_dict()
|
config_dict = model_config.to_dict()
|
||||||
config = cls.from_dict(config_dict, return_unused_kwargs=False)
|
config_dict.pop("_from_model_config", None)
|
||||||
|
config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
|
||||||
|
|
||||||
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
|
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
|
||||||
# generation config.
|
# generation config.
|
||||||
@@ -642,7 +669,6 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
|
if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
|
||||||
setattr(config, attr, decoder_config[attr])
|
setattr(config, attr, decoder_config[attr])
|
||||||
|
|
||||||
config._from_model_config = True
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def update(self, **kwargs):
|
def update(self, **kwargs):
|
||||||
|
|||||||
@@ -78,6 +78,20 @@ class GenerationConfigTest(unittest.TestCase):
|
|||||||
# `.update()` returns a dictionary of unused kwargs
|
# `.update()` returns a dictionary of unused kwargs
|
||||||
self.assertEqual(unused_kwargs, {"foo": "bar"})
|
self.assertEqual(unused_kwargs, {"foo": "bar"})
|
||||||
|
|
||||||
|
def test_initialize_new_kwargs(self):
    """
    Checks that an attribute set on a `GenerationConfig` outside of its predefined ones is kept through a
    `save_pretrained` / `from_pretrained` round trip, and is absent from a config built by `from_model_config`.
    """
    generation_config = GenerationConfig()
    generation_config.foo = "bar"

    with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
        generation_config.save_pretrained(tmp_dir)

        # must be loaded while the temporary directory still exists
        new_config = GenerationConfig.from_pretrained(tmp_dir)
    # the additional attribute was serialized and restored when loading
    self.assertEqual(new_config.foo, "bar")

    generation_config = GenerationConfig.from_model_config(new_config)
    # no new kwargs should be initialized if from config
    self.assertFalse(hasattr(generation_config, "foo"))
|
||||||
|
|
||||||
|
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
class ConfigPushToHubTester(unittest.TestCase):
|
class ConfigPushToHubTester(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user