Generate: save generation config with the models' .save_pretrained() (#21264)

This commit is contained in:
Joao Gante
2023-01-23 16:21:44 +00:00
committed by GitHub
parent cf1a1eed70
commit 1eda4a4102
7 changed files with 117 additions and 3 deletions

View File

@@ -17,11 +17,14 @@ import copy
import tempfile
import unittest
from huggingface_hub import HfFolder, delete_repo, set_access_token
from parameterized import parameterized
from requests.exceptions import HTTPError
from transformers import AutoConfig, GenerationConfig
from transformers.testing_utils import TOKEN, USER, is_staging_test
class LogitsProcessorTest(unittest.TestCase):
class GenerationConfigTest(unittest.TestCase):
@parameterized.expand([(None,), ("foo.json",)])
def test_save_load_config(self, config_name):
config = GenerationConfig(
@@ -74,3 +77,78 @@ class LogitsProcessorTest(unittest.TestCase):
# `.update()` returns a dictionary of unused kwargs
self.assertEqual(unused_kwargs, {"foo": "bar"})
@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls._token = TOKEN
set_access_token(TOKEN)
HfFolder.save_token(TOKEN)
@classmethod
def tearDownClass(cls):
try:
delete_repo(token=cls._token, repo_id="test-generation-config")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="valid_org/test-generation-config-org")
except HTTPError:
pass
def test_push_to_hub(self):
config = GenerationConfig(
do_sample=True,
temperature=0.7,
length_penalty=1.0,
)
config.push_to_hub("test-generation-config", use_auth_token=self._token)
new_config = GenerationConfig.from_pretrained(f"{USER}/test-generation-config")
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
# Reset repo
delete_repo(token=self._token, repo_id="test-generation-config")
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(
tmp_dir, repo_id="test-generation-config", push_to_hub=True, use_auth_token=self._token
)
new_config = GenerationConfig.from_pretrained(f"{USER}/test-generation-config")
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
def test_push_to_hub_in_organization(self):
config = GenerationConfig(
do_sample=True,
temperature=0.7,
length_penalty=1.0,
)
config.push_to_hub("valid_org/test-generation-config-org", use_auth_token=self._token)
new_config = GenerationConfig.from_pretrained("valid_org/test-generation-config-org")
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
# Reset repo
delete_repo(token=self._token, repo_id="valid_org/test-generation-config-org")
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(
tmp_dir, repo_id="valid_org/test-generation-config-org", push_to_hub=True, use_auth_token=self._token
)
new_config = GenerationConfig.from_pretrained("valid_org/test-generation-config-org")
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))