Generate: save generation config with the models' .save_pretrained() (#21264)

This commit is contained in:
Joao Gante
2023-01-23 16:21:44 +00:00
committed by GitHub
parent cf1a1eed70
commit 1eda4a4102
7 changed files with 117 additions and 3 deletions

View File

@@ -50,7 +50,14 @@ from transformers.testing_utils import ( # noqa: F401
tooslow,
torch_device,
)
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
from transformers.utils import (
CONFIG_NAME,
GENERATION_CONFIG_NAME,
SAFE_WEIGHTS_NAME,
TF2_WEIGHTS_INDEX_NAME,
TF2_WEIGHTS_NAME,
logging,
)
from transformers.utils.generic import ModelOutput
@@ -226,6 +233,13 @@ class TFModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=False)
# the config file (and the generation config file, if it can generate) should be saved
self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
self.assertEqual(
model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME))
)
model = model_class.from_pretrained(tmpdirname)
after_outputs = model(self._prepare_for_class(inputs_dict, model_class))