Generate: save generation config with the models' .save_pretrained() (#21264)
This commit is contained in:
@@ -1032,6 +1032,8 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
custom_object_save(self, save_directory, config=self.config)
|
custom_object_save(self, save_directory, config=self.config)
|
||||||
|
|
||||||
self.config.save_pretrained(save_directory)
|
self.config.save_pretrained(save_directory)
|
||||||
|
if self.can_generate():
|
||||||
|
self.generation_config.save_pretrained(save_directory)
|
||||||
|
|
||||||
# save model
|
# save model
|
||||||
output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
|
output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
|
||||||
|
|||||||
@@ -2306,6 +2306,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
custom_object_save(self, save_directory, config=self.config)
|
custom_object_save(self, save_directory, config=self.config)
|
||||||
|
|
||||||
self.config.save_pretrained(save_directory)
|
self.config.save_pretrained(save_directory)
|
||||||
|
if self.can_generate():
|
||||||
|
self.generation_config.save_pretrained(save_directory)
|
||||||
|
|
||||||
# If we save using the predefined names, we can load using `from_pretrained`
|
# If we save using the predefined names, we can load using `from_pretrained`
|
||||||
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME
|
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME
|
||||||
|
|||||||
@@ -1655,6 +1655,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# Save the config
|
# Save the config
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
model_to_save.config.save_pretrained(save_directory)
|
model_to_save.config.save_pretrained(save_directory)
|
||||||
|
if self.can_generate():
|
||||||
|
model_to_save.generation_config.save_pretrained(save_directory)
|
||||||
|
|
||||||
# Save the model
|
# Save the model
|
||||||
if state_dict is None:
|
if state_dict is None:
|
||||||
|
|||||||
@@ -17,11 +17,14 @@ import copy
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from huggingface_hub import HfFolder, delete_repo, set_access_token
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
from requests.exceptions import HTTPError
|
||||||
from transformers import AutoConfig, GenerationConfig
|
from transformers import AutoConfig, GenerationConfig
|
||||||
|
from transformers.testing_utils import TOKEN, USER, is_staging_test
|
||||||
|
|
||||||
|
|
||||||
class LogitsProcessorTest(unittest.TestCase):
|
class GenerationConfigTest(unittest.TestCase):
|
||||||
@parameterized.expand([(None,), ("foo.json",)])
|
@parameterized.expand([(None,), ("foo.json",)])
|
||||||
def test_save_load_config(self, config_name):
|
def test_save_load_config(self, config_name):
|
||||||
config = GenerationConfig(
|
config = GenerationConfig(
|
||||||
@@ -74,3 +77,78 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
|
|
||||||
# `.update()` returns a dictionary of unused kwargs
|
# `.update()` returns a dictionary of unused kwargs
|
||||||
self.assertEqual(unused_kwargs, {"foo": "bar"})
|
self.assertEqual(unused_kwargs, {"foo": "bar"})
|
||||||
|
|
||||||
|
|
||||||
|
@is_staging_test
|
||||||
|
class ConfigPushToHubTester(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls._token = TOKEN
|
||||||
|
set_access_token(TOKEN)
|
||||||
|
HfFolder.save_token(TOKEN)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="test-generation-config")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="valid_org/test-generation-config-org")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_push_to_hub(self):
|
||||||
|
config = GenerationConfig(
|
||||||
|
do_sample=True,
|
||||||
|
temperature=0.7,
|
||||||
|
length_penalty=1.0,
|
||||||
|
)
|
||||||
|
config.push_to_hub("test-generation-config", use_auth_token=self._token)
|
||||||
|
|
||||||
|
new_config = GenerationConfig.from_pretrained(f"{USER}/test-generation-config")
|
||||||
|
for k, v in config.to_dict().items():
|
||||||
|
if k != "transformers_version":
|
||||||
|
self.assertEqual(v, getattr(new_config, k))
|
||||||
|
|
||||||
|
# Reset repo
|
||||||
|
delete_repo(token=self._token, repo_id="test-generation-config")
|
||||||
|
|
||||||
|
# Push to hub via save_pretrained
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
config.save_pretrained(
|
||||||
|
tmp_dir, repo_id="test-generation-config", push_to_hub=True, use_auth_token=self._token
|
||||||
|
)
|
||||||
|
|
||||||
|
new_config = GenerationConfig.from_pretrained(f"{USER}/test-generation-config")
|
||||||
|
for k, v in config.to_dict().items():
|
||||||
|
if k != "transformers_version":
|
||||||
|
self.assertEqual(v, getattr(new_config, k))
|
||||||
|
|
||||||
|
def test_push_to_hub_in_organization(self):
|
||||||
|
config = GenerationConfig(
|
||||||
|
do_sample=True,
|
||||||
|
temperature=0.7,
|
||||||
|
length_penalty=1.0,
|
||||||
|
)
|
||||||
|
config.push_to_hub("valid_org/test-generation-config-org", use_auth_token=self._token)
|
||||||
|
|
||||||
|
new_config = GenerationConfig.from_pretrained("valid_org/test-generation-config-org")
|
||||||
|
for k, v in config.to_dict().items():
|
||||||
|
if k != "transformers_version":
|
||||||
|
self.assertEqual(v, getattr(new_config, k))
|
||||||
|
|
||||||
|
# Reset repo
|
||||||
|
delete_repo(token=self._token, repo_id="valid_org/test-generation-config-org")
|
||||||
|
|
||||||
|
# Push to hub via save_pretrained
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
config.save_pretrained(
|
||||||
|
tmp_dir, repo_id="valid_org/test-generation-config-org", push_to_hub=True, use_auth_token=self._token
|
||||||
|
)
|
||||||
|
|
||||||
|
new_config = GenerationConfig.from_pretrained("valid_org/test-generation-config-org")
|
||||||
|
for k, v in config.to_dict().items():
|
||||||
|
if k != "transformers_version":
|
||||||
|
self.assertEqual(v, getattr(new_config, k))
|
||||||
|
|||||||
@@ -63,6 +63,8 @@ from transformers.testing_utils import (
|
|||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
|
CONFIG_NAME,
|
||||||
|
GENERATION_CONFIG_NAME,
|
||||||
SAFE_WEIGHTS_INDEX_NAME,
|
SAFE_WEIGHTS_INDEX_NAME,
|
||||||
SAFE_WEIGHTS_NAME,
|
SAFE_WEIGHTS_NAME,
|
||||||
WEIGHTS_INDEX_NAME,
|
WEIGHTS_INDEX_NAME,
|
||||||
@@ -275,6 +277,13 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
model.save_pretrained(tmpdirname)
|
model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
# the config file (and the generation config file, if it can generate) should be saved
|
||||||
|
self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
|
||||||
|
self.assertEqual(
|
||||||
|
model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME))
|
||||||
|
)
|
||||||
|
|
||||||
model = model_class.from_pretrained(tmpdirname)
|
model = model_class.from_pretrained(tmpdirname)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ from transformers.testing_utils import (
|
|||||||
require_flax,
|
require_flax,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
from transformers.utils import logging
|
from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
|
||||||
from transformers.utils.generic import ModelOutput
|
from transformers.utils.generic import ModelOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -395,6 +395,13 @@ class FlaxModelTesterMixin:
|
|||||||
# verify that normal save_pretrained works as expected
|
# verify that normal save_pretrained works as expected
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
model.save_pretrained(tmpdirname)
|
model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
# the config file (and the generation config file, if it can generate) should be saved
|
||||||
|
self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
|
||||||
|
self.assertEqual(
|
||||||
|
model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME))
|
||||||
|
)
|
||||||
|
|
||||||
model_loaded = model_class.from_pretrained(tmpdirname)
|
model_loaded = model_class.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
|
outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
|
||||||
|
|||||||
@@ -50,7 +50,14 @@ from transformers.testing_utils import ( # noqa: F401
|
|||||||
tooslow,
|
tooslow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
|
from transformers.utils import (
|
||||||
|
CONFIG_NAME,
|
||||||
|
GENERATION_CONFIG_NAME,
|
||||||
|
SAFE_WEIGHTS_NAME,
|
||||||
|
TF2_WEIGHTS_INDEX_NAME,
|
||||||
|
TF2_WEIGHTS_NAME,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
from transformers.utils.generic import ModelOutput
|
from transformers.utils.generic import ModelOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -226,6 +233,13 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
model.save_pretrained(tmpdirname, saved_model=False)
|
model.save_pretrained(tmpdirname, saved_model=False)
|
||||||
|
|
||||||
|
# the config file (and the generation config file, if it can generate) should be saved
|
||||||
|
self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
|
||||||
|
self.assertEqual(
|
||||||
|
model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME))
|
||||||
|
)
|
||||||
|
|
||||||
model = model_class.from_pretrained(tmpdirname)
|
model = model_class.from_pretrained(tmpdirname)
|
||||||
after_outputs = model(self._prepare_for_class(inputs_dict, model_class))
|
after_outputs = model(self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user