Pipeline: no side-effects on model.config and model.generation_config 🔫 (#33480)

This commit is contained in:
Joao Gante
2024-09-18 15:43:06 +01:00
committed by GitHub
parent fc83a4d459
commit 7542fac2c7
13 changed files with 132 additions and 30 deletions

View File

@@ -1715,6 +1715,38 @@ class ModelUtilsTest(TestCasePlus):
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
)
def test_save_and_load_config_with_custom_generation(self):
"""
Regression test for the ability to save and load a config with a custom generation kwarg (i.e. a parameter
that gets moved to the generation config and reset on the model config)
"""
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
# The default for `num_beams` is 1 and `early_stopping` is False
self.assertTrue(model.config.num_beams == 1)
self.assertTrue(model.config.early_stopping is False)
# When we save the model, this custom parameter should be moved to the generation config AND the model
# config should contain `None`
model.config.num_beams = 2
model.config.early_stopping = True
self.assertTrue(model.generation_config.num_beams == 1) # unmodified generation config
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
new_model = T5ForConditionalGeneration.from_pretrained(tmp_dir)
# moved to generation config
self.assertTrue(new_model.generation_config.num_beams == 2)
self.assertTrue(new_model.generation_config.early_stopping is True)
# reset in the model config
self.assertTrue(new_model.config.num_beams is None)
self.assertTrue(new_model.config.early_stopping is None)
# Sanity check: We can run `generate` with the new model without any warnings
random_ids = torch.randint(0, 100, (1, 5))
with warnings.catch_warnings(record=True) as w:
new_model.generate(random_ids, max_new_tokens=3)
self.assertTrue(len(w) == 0)
@slow
@require_torch