Forbid PretrainedConfig from saving generate parameters; Update deprecations in generate-related code 🧹 (#32659)

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Joao Gante
2024-08-23 11:12:53 +01:00
committed by GitHub
parent 22e6f14525
commit 970a16ec7f
53 changed files with 195 additions and 670 deletions

View File

@@ -23,6 +23,7 @@ import threading
import unittest
import unittest.mock as mock
import uuid
import warnings
from pathlib import Path
import requests
@@ -1599,14 +1600,30 @@ class ModelUtilsTest(TestCasePlus):
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
def test_modifying_model_config_causes_warning_saving_generation_config(self):
def test_modifying_model_config_gets_moved_to_generation_config(self):
"""
Calling `model.save_pretrained` should move the changes made to `generate` parameterization in the model config
to the generation config.
"""
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
model.config.top_k = 1
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertLogs("transformers.modeling_utils", level="WARNING") as logs:
# Initially, the repetition penalty has its default value in `model.config`. The `model.generation_config` will
# have the exact same default
self.assertTrue(model.config.repetition_penalty == 1.0)
self.assertTrue(model.generation_config.repetition_penalty == 1.0)
# If the user attempts to save a custom generation parameter:
model.config.repetition_penalty = 3.0
with warnings.catch_warnings(record=True) as warning_list:
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
self.assertEqual(len(logs.output), 1)
self.assertIn("Your generation config was originally created from the model config", logs.output[0])
# 1 - That parameter will be removed from `model.config`. We don't want to use `model.config` to store
# generative parameters, and the old default (1.0) would no longer relect the user's wishes.
self.assertTrue(model.config.repetition_penalty is None)
# 2 - That parameter will be set in `model.generation_config` instead.
self.assertTrue(model.generation_config.repetition_penalty == 3.0)
# 3 - The user will see a warning regarding the custom parameter that has been moved.
self.assertTrue(len(warning_list) == 1)
self.assertTrue("Moving the following attributes" in str(warning_list[0].message))
self.assertTrue("repetition_penalty" in str(warning_list[0].message))
@require_safetensors
def test_model_from_pretrained_from_mlx(self):