Forbid PretrainedConfig from saving generate parameters; Update deprecations in generate-related code 🧹 (#32659)
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -18,7 +18,7 @@ import shutil
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers.testing_utils import CaptureStd, is_pt_tf_cross_test, require_torch
|
||||
from transformers.testing_utils import CaptureStd, require_torch
|
||||
|
||||
|
||||
class CLITest(unittest.TestCase):
|
||||
@@ -33,18 +33,6 @@ class CLITest(unittest.TestCase):
|
||||
self.assertIn("Platform", cs.out)
|
||||
self.assertIn("Using distributed or parallel set-up in script?", cs.out)
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
@patch(
|
||||
"sys.argv", ["fakeprogrampath", "pt-to-tf", "--model-name", "hf-internal-testing/tiny-random-gptj", "--no-pr"]
|
||||
)
|
||||
def test_cli_pt_to_tf(self):
|
||||
import transformers.commands.transformers_cli
|
||||
|
||||
shutil.rmtree("/tmp/hf-internal-testing/tiny-random-gptj", ignore_errors=True) # cleans potential past runs
|
||||
transformers.commands.transformers_cli.main()
|
||||
|
||||
self.assertTrue(os.path.exists("/tmp/hf-internal-testing/tiny-random-gptj/tf_model.h5"))
|
||||
|
||||
@require_torch
|
||||
@patch("sys.argv", ["fakeprogrampath", "download", "hf-internal-testing/tiny-random-gptj", "--cache-dir", "/tmp"])
|
||||
def test_cli_download(self):
|
||||
|
||||
@@ -315,21 +315,19 @@ class ConfigTestUtils(unittest.TestCase):
|
||||
old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
|
||||
self.assertEqual(old_configuration.hidden_size, 768)
|
||||
|
||||
def test_saving_config_with_custom_generation_kwargs_raises_warning(self):
|
||||
def test_saving_config_with_custom_generation_kwargs_raises_exception(self):
|
||||
config = BertConfig(min_length=3) # `min_length = 3` is a non-default generation kwarg
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with self.assertLogs("transformers.configuration_utils", level="WARNING") as logs:
|
||||
with self.assertRaises(ValueError):
|
||||
config.save_pretrained(tmp_dir)
|
||||
self.assertEqual(len(logs.output), 1)
|
||||
self.assertIn("min_length", logs.output[0])
|
||||
|
||||
def test_has_non_default_generation_parameters(self):
|
||||
def test_get_non_default_generation_parameters(self):
|
||||
config = BertConfig()
|
||||
self.assertFalse(config._has_non_default_generation_parameters())
|
||||
self.assertFalse(len(config._get_non_default_generation_parameters()) > 0)
|
||||
config = BertConfig(min_length=3)
|
||||
self.assertTrue(config._has_non_default_generation_parameters())
|
||||
self.assertTrue(len(config._get_non_default_generation_parameters()) > 0)
|
||||
config = BertConfig(min_length=0) # `min_length = 0` is a default generation kwarg
|
||||
self.assertFalse(config._has_non_default_generation_parameters())
|
||||
self.assertFalse(len(config._get_non_default_generation_parameters()) > 0)
|
||||
|
||||
def test_loading_config_do_not_raise_future_warnings(self):
|
||||
"""Regression test for https://github.com/huggingface/transformers/issues/31002."""
|
||||
|
||||
@@ -23,6 +23,7 @@ import threading
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
import uuid
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
@@ -1599,14 +1600,30 @@ class ModelUtilsTest(TestCasePlus):
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
def test_modifying_model_config_causes_warning_saving_generation_config(self):
|
||||
def test_modifying_model_config_gets_moved_to_generation_config(self):
|
||||
"""
|
||||
Calling `model.save_pretrained` should move the changes made to `generate` parameterization in the model config
|
||||
to the generation config.
|
||||
"""
|
||||
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
||||
model.config.top_k = 1
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with self.assertLogs("transformers.modeling_utils", level="WARNING") as logs:
|
||||
# Initially, the repetition penalty has its default value in `model.config`. The `model.generation_config` will
|
||||
# have the exact same default
|
||||
self.assertTrue(model.config.repetition_penalty == 1.0)
|
||||
self.assertTrue(model.generation_config.repetition_penalty == 1.0)
|
||||
# If the user attempts to save a custom generation parameter:
|
||||
model.config.repetition_penalty = 3.0
|
||||
with warnings.catch_warnings(record=True) as warning_list:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
self.assertEqual(len(logs.output), 1)
|
||||
self.assertIn("Your generation config was originally created from the model config", logs.output[0])
|
||||
# 1 - That parameter will be removed from `model.config`. We don't want to use `model.config` to store
|
||||
# generative parameters, and the old default (1.0) would no longer relect the user's wishes.
|
||||
self.assertTrue(model.config.repetition_penalty is None)
|
||||
# 2 - That parameter will be set in `model.generation_config` instead.
|
||||
self.assertTrue(model.generation_config.repetition_penalty == 3.0)
|
||||
# 3 - The user will see a warning regarding the custom parameter that has been moved.
|
||||
self.assertTrue(len(warning_list) == 1)
|
||||
self.assertTrue("Moving the following attributes" in str(warning_list[0].message))
|
||||
self.assertTrue("repetition_penalty" in str(warning_list[0].message))
|
||||
|
||||
@require_safetensors
|
||||
def test_model_from_pretrained_from_mlx(self):
|
||||
|
||||
Reference in New Issue
Block a user