Generate: save generation config with the models' .save_pretrained() (#21264)
This commit is contained in:
@@ -63,6 +63,8 @@ from transformers.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import (
|
||||
CONFIG_NAME,
|
||||
GENERATION_CONFIG_NAME,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
@@ -275,6 +277,13 @@ class ModelTesterMixin:
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
# the config file (and the generation config file, if it can generate) should be saved
|
||||
self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
|
||||
self.assertEqual(
|
||||
model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME))
|
||||
)
|
||||
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
model.to(torch_device)
|
||||
with torch.no_grad():
|
||||
|
||||
Reference in New Issue
Block a user