From b0c5660e881a562ba526f70639ff480a392fce71 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 3 Oct 2024 16:45:14 +0100 Subject: [PATCH] Config: lower `save_pretrained` exception to warning (#33906) * lower to warning * msg * make fixup * rm extra comma --- src/transformers/configuration_utils.py | 9 ++++++--- tests/utils/test_configuration_utils.py | 5 +++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 92e5425e95..8d8784be39 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -380,11 +380,14 @@ class PretrainedConfig(PushToHubMixin): non_default_generation_parameters = self._get_non_default_generation_parameters() if len(non_default_generation_parameters) > 0: - raise ValueError( + # TODO (joao): this should be an exception if the user has modified the loaded config. See #33886 + warnings.warn( "Some non-default generation parameters are set in the model config. These should go into either a) " "`model.generation_config` (as opposed to `model.config`); OR b) a GenerationConfig file " - "(https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) " - f"\nNon-default generation parameters: {str(non_default_generation_parameters)}" + "(https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model)." + "This warning will become an exception in the future." + f"\nNon-default generation parameters: {str(non_default_generation_parameters)}", + UserWarning, ) os.makedirs(save_directory, exist_ok=True) diff --git a/tests/utils/test_configuration_utils.py b/tests/utils/test_configuration_utils.py index 76394daf9c..d2701bf35e 100644 --- a/tests/utils/test_configuration_utils.py +++ b/tests/utils/test_configuration_utils.py @@ -313,11 +313,12 @@ class ConfigTestUtils(unittest.TestCase): old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo) self.assertEqual(old_configuration.hidden_size, 768) - def test_saving_config_with_custom_generation_kwargs_raises_exception(self): + def test_saving_config_with_custom_generation_kwargs_raises_warning(self): config = BertConfig(min_length=3) # `min_length = 3` is a non-default generation kwarg with tempfile.TemporaryDirectory() as tmp_dir: - with self.assertRaises(ValueError): + with self.assertWarns(UserWarning) as cm: config.save_pretrained(tmp_dir) + self.assertIn("min_length", str(cm.warning)) def test_get_non_default_generation_parameters(self): config = BertConfig()