From b1ea6b4bf57a9117a6bb24b4b7c8856b1e05dee3 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 13 Jun 2023 17:59:21 +0100 Subject: [PATCH] Generate: GenerationConfig can overwrite attributes at from_pretrained time (#24238) Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../generation/configuration_utils.py | 9 ++++--- tests/generation/test_configuration_utils.py | 25 +++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 9f3bedcdeb..d024ff4718 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -288,7 +288,8 @@ class GenerationConfig(PushToHubMixin): # Additional attributes without default values if not self._from_model_config: - # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a model's default configuration file + # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a + # model's default configuration file for key, value in kwargs.items(): try: setattr(self, key, value) @@ -569,9 +570,9 @@ class GenerationConfig(PushToHubMixin): if "_commit_hash" in kwargs and "_commit_hash" in config_dict: kwargs["_commit_hash"] = config_dict["_commit_hash"] - # remove all the arguments that are in the config_dict - - config = cls(**config_dict, **kwargs) + # The line below allows model-specific config to be loaded as well through kwargs, with safety checks. + # See https://github.com/huggingface/transformers/pull/21269 + config = cls(**{**config_dict, **kwargs}) unused_kwargs = config.update(**kwargs) logger.info(f"Generate config {config}") diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py index a12b359682..c2dd7b005b 100644 --- a/tests/generation/test_configuration_utils.py +++ b/tests/generation/test_configuration_utils.py @@ -93,6 +93,31 @@ class GenerationConfigTest(unittest.TestCase): generation_config = GenerationConfig.from_model_config(new_config) assert not hasattr(generation_config, "foo") # no new kwargs should be initialized if from config + def test_kwarg_init(self): + """Tests that we can overwrite attributes at `from_pretrained` time.""" + default_config = GenerationConfig() + self.assertEqual(default_config.temperature, 1.0) + self.assertEqual(default_config.do_sample, False) + self.assertEqual(default_config.num_beams, 1) + + config = GenerationConfig( + do_sample=True, + temperature=0.7, + length_penalty=1.0, + bad_words_ids=[[1, 2, 3], [4, 5]], + ) + self.assertEqual(config.temperature, 0.7) + self.assertEqual(config.do_sample, True) + self.assertEqual(config.num_beams, 1) + + with tempfile.TemporaryDirectory() as tmp_dir: + config.save_pretrained(tmp_dir) + loaded_config = GenerationConfig.from_pretrained(tmp_dir, temperature=1.0) + + self.assertEqual(loaded_config.temperature, 1.0) + self.assertEqual(loaded_config.do_sample, True) + self.assertEqual(loaded_config.num_beams, 1) # default value + @is_staging_test class ConfigPushToHubTester(unittest.TestCase):