Generate: save generation config with the models' .save_pretrained() (#21264)

This commit is contained in:
Joao Gante
2023-01-23 16:21:44 +00:00
committed by GitHub
parent cf1a1eed70
commit 1eda4a4102
7 changed files with 117 additions and 3 deletions

View File

@@ -36,7 +36,7 @@ from transformers.testing_utils import (
require_flax,
torch_device,
)
from transformers.utils import logging
from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
from transformers.utils.generic import ModelOutput
@@ -395,6 +395,13 @@ class FlaxModelTesterMixin:
# verify that normal save_pretrained works as expected
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# the config file (and the generation config file, if it can generate) should be saved
self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
self.assertEqual(
model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME))
)
model_loaded = model_class.from_pretrained(tmpdirname)
outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()