Pipeline: no side-effects on model.config and model.generation_config 🔫 (#33480)
This commit is contained in:
@@ -1715,6 +1715,38 @@ class ModelUtilsTest(TestCasePlus):
|
||||
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
|
||||
)
|
||||
|
||||
def test_save_and_load_config_with_custom_generation(self):
|
||||
"""
|
||||
Regression test for the ability to save and load a config with a custom generation kwarg (i.e. a parameter
|
||||
that gets moved to the generation config and reset on the model config)
|
||||
"""
|
||||
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
|
||||
|
||||
# The default for `num_beams` is 1 and `early_stopping` is False
|
||||
self.assertTrue(model.config.num_beams == 1)
|
||||
self.assertTrue(model.config.early_stopping is False)
|
||||
|
||||
# When we save the model, this custom parameter should be moved to the generation config AND the model
|
||||
# config should contain `None`
|
||||
model.config.num_beams = 2
|
||||
model.config.early_stopping = True
|
||||
self.assertTrue(model.generation_config.num_beams == 1) # unmodified generation config
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
new_model = T5ForConditionalGeneration.from_pretrained(tmp_dir)
|
||||
# moved to generation config
|
||||
self.assertTrue(new_model.generation_config.num_beams == 2)
|
||||
self.assertTrue(new_model.generation_config.early_stopping is True)
|
||||
# reset in the model config
|
||||
self.assertTrue(new_model.config.num_beams is None)
|
||||
self.assertTrue(new_model.config.early_stopping is None)
|
||||
|
||||
# Sanity check: We can run `generate` with the new model without any warnings
|
||||
random_ids = torch.randint(0, 100, (1, 5))
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
new_model.generate(random_ids, max_new_tokens=3)
|
||||
self.assertTrue(len(w) == 0)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user