Config: warning when saving generation kwargs in the model config (#28514)

This commit is contained in:
Joao Gante
2024-01-16 18:31:01 +00:00
committed by GitHub
parent 7142bdfa90
commit f4f57f9dfa
6 changed files with 107 additions and 32 deletions

View File

@@ -1230,6 +1230,15 @@ class ModelUtilsTest(TestCasePlus):
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
def test_modifying_model_config_causes_warning_saving_generation_config(self):
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.config.top_k = 1
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertLogs("transformers.modeling_utils", level="WARNING") as logs:
model.save_pretrained(tmp_dir)
self.assertEqual(len(logs.output), 1)
self.assertIn("Your generation config was originally created from the model config", logs.output[0])
@slow
@require_torch