Config: warning when saving generation kwargs in the model config (#28514)

This commit is contained in:
Joao Gante
2024-01-16 18:31:01 +00:00
committed by GitHub
parent 7142bdfa90
commit f4f57f9dfa
6 changed files with 107 additions and 32 deletions

View File

@@ -529,7 +529,7 @@ class GenerationIntegrationTestsMixin:
pixel_values = floats_tensor((2, 3, 30, 30))
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2")
model.config.decoder.eos_token_id = None
model.generation_config.eos_token_id = None
if is_pt:
pixel_values = pixel_values.to(torch_device)
model = model.to(torch_device)

View File

@@ -296,3 +296,19 @@ class ConfigTestUtils(unittest.TestCase):
old_transformers.configuration_utils.__version__ = "v3.0.0"
old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
self.assertEqual(old_configuration.hidden_size, 768)
def test_saving_config_with_custom_generation_kwargs_raises_warning(self):
config = BertConfig(min_length=3) # `min_length = 3` is a non-default generation kwarg
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertLogs("transformers.configuration_utils", level="WARNING") as logs:
config.save_pretrained(tmp_dir)
self.assertEqual(len(logs.output), 1)
self.assertIn("min_length", logs.output[0])
def test_has_non_default_generation_parameters(self):
config = BertConfig()
self.assertFalse(config._has_non_default_generation_parameters())
config = BertConfig(min_length=3)
self.assertTrue(config._has_non_default_generation_parameters())
config = BertConfig(min_length=0) # `min_length = 0` is a default generation kwarg
self.assertFalse(config._has_non_default_generation_parameters())

View File

@@ -1230,6 +1230,15 @@ class ModelUtilsTest(TestCasePlus):
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
def test_modifying_model_config_causes_warning_saving_generation_config(self):
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.config.top_k = 1
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertLogs("transformers.modeling_utils", level="WARNING") as logs:
model.save_pretrained(tmp_dir)
self.assertEqual(len(logs.output), 1)
self.assertIn("Your generation config was originally created from the model config", logs.output[0])
@slow
@require_torch