From 1eda4a410298d57156d44bfc39a6001a72554412 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 23 Jan 2023 16:21:44 +0000 Subject: [PATCH] Generate: save generation config with the models' `.save_pretrained()` (#21264) --- src/transformers/modeling_flax_utils.py | 2 + src/transformers/modeling_tf_utils.py | 2 + src/transformers/modeling_utils.py | 2 + tests/generation/test_configuration_utils.py | 80 +++++++++++++++++++- tests/test_modeling_common.py | 9 +++ tests/test_modeling_flax_common.py | 9 ++- tests/test_modeling_tf_common.py | 16 +++- 7 files changed, 117 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index a643b43ab6..ade1a5063b 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -1032,6 +1032,8 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): custom_object_save(self, save_directory, config=self.config) self.config.save_pretrained(save_directory) + if self.can_generate(): + self.generation_config.save_pretrained(save_directory) # save model output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 6572a0f859..f27373bd78 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -2306,6 +2306,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu custom_object_save(self, save_directory, config=self.config) self.config.save_pretrained(save_directory) + if self.can_generate(): + self.generation_config.save_pretrained(save_directory) # If we save using the predefined names, we can load using `from_pretrained` weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 77f0bc117e..4017305842 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1655,6 +1655,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Save the config if is_main_process: model_to_save.config.save_pretrained(save_directory) + if self.can_generate(): + model_to_save.generation_config.save_pretrained(save_directory) # Save the model if state_dict is None: diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py index 004720e110..fcc481f209 100644 --- a/tests/generation/test_configuration_utils.py +++ b/tests/generation/test_configuration_utils.py @@ -17,11 +17,14 @@ import copy import tempfile import unittest +from huggingface_hub import HfFolder, delete_repo, set_access_token from parameterized import parameterized +from requests.exceptions import HTTPError from transformers import AutoConfig, GenerationConfig +from transformers.testing_utils import TOKEN, USER, is_staging_test -class LogitsProcessorTest(unittest.TestCase): +class GenerationConfigTest(unittest.TestCase): @parameterized.expand([(None,), ("foo.json",)]) def test_save_load_config(self, config_name): config = GenerationConfig( @@ -74,3 +77,78 @@ class LogitsProcessorTest(unittest.TestCase): # `.update()` returns a dictionary of unused kwargs self.assertEqual(unused_kwargs, {"foo": "bar"}) + + +@is_staging_test +class ConfigPushToHubTester(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._token = TOKEN + set_access_token(TOKEN) + HfFolder.save_token(TOKEN) + + @classmethod + def tearDownClass(cls): + try: + delete_repo(token=cls._token, repo_id="test-generation-config") + except HTTPError: + pass + + try: + delete_repo(token=cls._token, repo_id="valid_org/test-generation-config-org") + except HTTPError: + pass + + def test_push_to_hub(self): + config = GenerationConfig( + do_sample=True, + temperature=0.7, + length_penalty=1.0, + ) + config.push_to_hub("test-generation-config", use_auth_token=self._token) + + new_config = GenerationConfig.from_pretrained(f"{USER}/test-generation-config") + for k, v in config.to_dict().items(): + if k != "transformers_version": + self.assertEqual(v, getattr(new_config, k)) + + # Reset repo + delete_repo(token=self._token, repo_id="test-generation-config") + + # Push to hub via save_pretrained + with tempfile.TemporaryDirectory() as tmp_dir: + config.save_pretrained( + tmp_dir, repo_id="test-generation-config", push_to_hub=True, use_auth_token=self._token + ) + + new_config = GenerationConfig.from_pretrained(f"{USER}/test-generation-config") + for k, v in config.to_dict().items(): + if k != "transformers_version": + self.assertEqual(v, getattr(new_config, k)) + + def test_push_to_hub_in_organization(self): + config = GenerationConfig( + do_sample=True, + temperature=0.7, + length_penalty=1.0, + ) + config.push_to_hub("valid_org/test-generation-config-org", use_auth_token=self._token) + + new_config = GenerationConfig.from_pretrained("valid_org/test-generation-config-org") + for k, v in config.to_dict().items(): + if k != "transformers_version": + self.assertEqual(v, getattr(new_config, k)) + + # Reset repo + delete_repo(token=self._token, repo_id="valid_org/test-generation-config-org") + + # Push to hub via save_pretrained + with tempfile.TemporaryDirectory() as tmp_dir: + config.save_pretrained( + tmp_dir, repo_id="valid_org/test-generation-config-org", push_to_hub=True, use_auth_token=self._token + ) + + new_config = GenerationConfig.from_pretrained("valid_org/test-generation-config-org") + for k, v in config.to_dict().items(): + if k != "transformers_version": + self.assertEqual(v, getattr(new_config, k)) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 1de50c8f90..01b7e47b88 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -63,6 +63,8 @@ from transformers.testing_utils import ( torch_device, ) from transformers.utils import ( + CONFIG_NAME, + GENERATION_CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, @@ -275,6 +277,13 @@ class ModelTesterMixin: with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) + + # the config file (and the generation config file, if it can generate) should be saved + self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME))) + self.assertEqual( + model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME)) + ) + model = model_class.from_pretrained(tmpdirname) model.to(torch_device) with torch.no_grad(): diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 81ae330746..fe1cb69435 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -36,7 +36,7 @@ from transformers.testing_utils import ( require_flax, torch_device, ) -from transformers.utils import logging +from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging from transformers.utils.generic import ModelOutput @@ -395,6 +395,13 @@ class FlaxModelTesterMixin: # verify that normal save_pretrained works as expected with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) + + # the config file (and the generation config file, if it can generate) should be saved + self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME))) + self.assertEqual( + model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME)) + ) + model_loaded = model_class.from_pretrained(tmpdirname) outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple() diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 178c50de18..b1359142bb 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -50,7 +50,14 @@ from transformers.testing_utils import ( # noqa: F401 tooslow, torch_device, ) -from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging +from transformers.utils import ( + CONFIG_NAME, + GENERATION_CONFIG_NAME, + SAFE_WEIGHTS_NAME, + TF2_WEIGHTS_INDEX_NAME, + TF2_WEIGHTS_NAME, + logging, +) from transformers.utils.generic import ModelOutput @@ -226,6 +233,13 @@ class TFModelTesterMixin: with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname, saved_model=False) + + # the config file (and the generation config file, if it can generate) should be saved + self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME))) + self.assertEqual( + model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME)) + ) + model = model_class.from_pretrained(tmpdirname) after_outputs = model(self._prepare_for_class(inputs_dict, model_class))