Fix PushToHubMixin when pusing to a PR revision (#34090)

This commit is contained in:
Lucain
2024-10-11 15:06:15 +02:00
committed by GitHub
parent 409dd2d19c
commit 1c66be8062
2 changed files with 28 additions and 2 deletions

View File

@@ -802,7 +802,7 @@ class PushToHubMixin:
CommitOperationAdd(path_or_fileobj=os.path.join(working_dir, file), path_in_repo=file) CommitOperationAdd(path_or_fileobj=os.path.join(working_dir, file), path_in_repo=file)
) )
if revision is not None: if revision is not None and not revision.startswith("refs/pr"):
try: try:
create_branch(repo_id=repo_id, branch=revision, token=token, exist_ok=True) create_branch(repo_id=repo_id, branch=revision, token=token, exist_ok=True)
except HfHubHTTPError as e: except HfHubHTTPError as e:

View File

@@ -20,7 +20,7 @@ import unittest
import warnings import warnings
from pathlib import Path from pathlib import Path
from huggingface_hub import HfFolder, delete_repo from huggingface_hub import HfFolder, create_pull_request, create_repo, delete_repo
from parameterized import parameterized from parameterized import parameterized
from transformers import AutoConfig, GenerationConfig, WatermarkingConfig, is_torch_available from transformers import AutoConfig, GenerationConfig, WatermarkingConfig, is_torch_available
@@ -764,3 +764,29 @@ class ConfigPushToHubTester(unittest.TestCase):
finally: finally:
# Always (try to) delete the repo. # Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token) self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_on_pr_revision(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
# create a repo and a PR
repo_id = f"{USER}/test-generation-config-{Path(tmp_dir).name}"
create_repo(repo_id=repo_id, token=self._token)
pr = create_pull_request(repo_id=repo_id, title="Test PR", token=self._token)
revision = f"refs/pr/{pr.num}"
# push to PR ref
config = GenerationConfig(
do_sample=True,
temperature=0.7,
length_penalty=1.0,
)
config.push_to_hub(repo_id, token=self._token, revision=revision)
# load from PR ref
new_config = GenerationConfig.from_pretrained(repo_id, revision=revision)
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=repo_id, token=self._token)