Generate: save generation config with the models' .save_pretrained() (#21264)
This commit is contained in:
@@ -36,7 +36,7 @@ from transformers.testing_utils import (
|
||||
require_flax,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
|
||||
from transformers.utils.generic import ModelOutput
|
||||
|
||||
|
||||
@@ -395,6 +395,13 @@ class FlaxModelTesterMixin:
|
||||
# verify that normal save_pretrained works as expected
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
# the config file (and the generation config file, if it can generate) should be saved
|
||||
self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
|
||||
self.assertEqual(
|
||||
model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME))
|
||||
)
|
||||
|
||||
model_loaded = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
|
||||
|
||||
Reference in New Issue
Block a user