Fix PushToHubMixin when pusing to a PR revision (#34090)

This commit is contained in:
Lucain
2024-10-11 15:06:15 +02:00
committed by GitHub
parent 409dd2d19c
commit 1c66be8062
2 changed files with 28 additions and 2 deletions

View File

@@ -20,7 +20,7 @@ import unittest
import warnings
from pathlib import Path
from huggingface_hub import HfFolder, delete_repo
from huggingface_hub import HfFolder, create_pull_request, create_repo, delete_repo
from parameterized import parameterized
from transformers import AutoConfig, GenerationConfig, WatermarkingConfig, is_torch_available
@@ -764,3 +764,29 @@ class ConfigPushToHubTester(unittest.TestCase):
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_on_pr_revision(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
# create a repo and a PR
repo_id = f"{USER}/test-generation-config-{Path(tmp_dir).name}"
create_repo(repo_id=repo_id, token=self._token)
pr = create_pull_request(repo_id=repo_id, title="Test PR", token=self._token)
revision = f"refs/pr/{pr.num}"
# push to PR ref
config = GenerationConfig(
do_sample=True,
temperature=0.7,
length_penalty=1.0,
)
config.push_to_hub(repo_id, token=self._token, revision=revision)
# load from PR ref
new_config = GenerationConfig.from_pretrained(repo_id, revision=revision)
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=repo_id, token=self._token)