Generate: save generation config with the models' .save_pretrained() (#21264)
This commit is contained in:
@@ -50,7 +50,14 @@ from transformers.testing_utils import ( # noqa: F401
|
||||
tooslow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
|
||||
from transformers.utils import (
|
||||
CONFIG_NAME,
|
||||
GENERATION_CONFIG_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
TF2_WEIGHTS_INDEX_NAME,
|
||||
TF2_WEIGHTS_NAME,
|
||||
logging,
|
||||
)
|
||||
from transformers.utils.generic import ModelOutput
|
||||
|
||||
|
||||
@@ -226,6 +233,13 @@ class TFModelTesterMixin:
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, saved_model=False)
|
||||
|
||||
# the config file (and the generation config file, if it can generate) should be saved
|
||||
self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
|
||||
self.assertEqual(
|
||||
model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME))
|
||||
)
|
||||
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
after_outputs = model(self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user