Config: lower save_pretrained exception to warning (#33906)
* lower to warning * msg * make fixup * rm extra comma
This commit is contained in:
@@ -380,11 +380,14 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
|
|
||||||
non_default_generation_parameters = self._get_non_default_generation_parameters()
|
non_default_generation_parameters = self._get_non_default_generation_parameters()
|
||||||
if len(non_default_generation_parameters) > 0:
|
if len(non_default_generation_parameters) > 0:
|
||||||
raise ValueError(
|
# TODO (joao): this should be an exception if the user has modified the loaded config. See #33886
|
||||||
|
warnings.warn(
|
||||||
"Some non-default generation parameters are set in the model config. These should go into either a) "
|
"Some non-default generation parameters are set in the model config. These should go into either a) "
|
||||||
"`model.generation_config` (as opposed to `model.config`); OR b) a GenerationConfig file "
|
"`model.generation_config` (as opposed to `model.config`); OR b) a GenerationConfig file "
|
||||||
"(https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) "
|
"(https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model)."
|
||||||
f"\nNon-default generation parameters: {str(non_default_generation_parameters)}"
|
"This warning will become an exception in the future."
|
||||||
|
f"\nNon-default generation parameters: {str(non_default_generation_parameters)}",
|
||||||
|
UserWarning,
|
||||||
)
|
)
|
||||||
|
|
||||||
os.makedirs(save_directory, exist_ok=True)
|
os.makedirs(save_directory, exist_ok=True)
|
||||||
|
|||||||
@@ -313,11 +313,12 @@ class ConfigTestUtils(unittest.TestCase):
|
|||||||
old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
|
old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
|
||||||
self.assertEqual(old_configuration.hidden_size, 768)
|
self.assertEqual(old_configuration.hidden_size, 768)
|
||||||
|
|
||||||
def test_saving_config_with_custom_generation_kwargs_raises_exception(self):
|
def test_saving_config_with_custom_generation_kwargs_raises_warning(self):
|
||||||
config = BertConfig(min_length=3) # `min_length = 3` is a non-default generation kwarg
|
config = BertConfig(min_length=3) # `min_length = 3` is a non-default generation kwarg
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
with self.assertRaises(ValueError):
|
with self.assertWarns(UserWarning) as cm:
|
||||||
config.save_pretrained(tmp_dir)
|
config.save_pretrained(tmp_dir)
|
||||||
|
self.assertIn("min_length", str(cm.warning))
|
||||||
|
|
||||||
def test_get_non_default_generation_parameters(self):
|
def test_get_non_default_generation_parameters(self):
|
||||||
config = BertConfig()
|
config = BertConfig()
|
||||||
|
|||||||
Reference in New Issue
Block a user