Generate: GenerationConfig can overwrite attributes at from_pretrained time (#24238)
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -288,7 +288,8 @@ class GenerationConfig(PushToHubMixin):
|
||||
|
||||
# Additional attributes without default values
|
||||
if not self._from_model_config:
|
||||
# we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a model's default configuration file
|
||||
# we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
|
||||
# model's default configuration file
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
setattr(self, key, value)
|
||||
@@ -569,9 +570,9 @@ class GenerationConfig(PushToHubMixin):
|
||||
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
|
||||
kwargs["_commit_hash"] = config_dict["_commit_hash"]
|
||||
|
||||
# remove all the arguments that are in the config_dict
|
||||
|
||||
config = cls(**config_dict, **kwargs)
|
||||
# The line below allows model-specific config to be loaded as well through kwargs, with safety checks.
|
||||
# See https://github.com/huggingface/transformers/pull/21269
|
||||
config = cls(**{**config_dict, **kwargs})
|
||||
unused_kwargs = config.update(**kwargs)
|
||||
|
||||
logger.info(f"Generate config {config}")
|
||||
|
||||
@@ -93,6 +93,31 @@ class GenerationConfigTest(unittest.TestCase):
|
||||
generation_config = GenerationConfig.from_model_config(new_config)
|
||||
assert not hasattr(generation_config, "foo") # no new kwargs should be initialized if from config
|
||||
|
||||
def test_kwarg_init(self):
    """Check that keyword arguments override config values, at construction and at `from_pretrained` time."""
    # A config built with no arguments carries these values for the three attributes under test.
    baseline = GenerationConfig()
    for attr_name, expected in (("temperature", 1.0), ("do_sample", False), ("num_beams", 1)):
        self.assertEqual(getattr(baseline, attr_name), expected)

    # Explicit kwargs replace the corresponding values; attributes that are not passed keep theirs.
    overrides = {
        "do_sample": True,
        "temperature": 0.7,
        "length_penalty": 1.0,
        "bad_words_ids": [[1, 2, 3], [4, 5]],
    }
    custom = GenerationConfig(**overrides)
    self.assertEqual(custom.temperature, 0.7)
    self.assertEqual(custom.do_sample, True)
    self.assertEqual(custom.num_beams, 1)

    # Round-trip through disk, overriding `temperature` again while loading.
    with tempfile.TemporaryDirectory() as save_dir:
        custom.save_pretrained(save_dir)
        reloaded = GenerationConfig.from_pretrained(save_dir, temperature=1.0)

    self.assertEqual(reloaded.temperature, 1.0)  # kwarg passed to `from_pretrained` wins over the saved 0.7
    self.assertEqual(reloaded.do_sample, True)  # value restored from the saved file
    self.assertEqual(reloaded.num_beams, 1)  # default value
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class ConfigPushToHubTester(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user