From 1c66be80624c05e8a381990378f994ebedd9128f Mon Sep 17 00:00:00 2001 From: Lucain Date: Fri, 11 Oct 2024 15:06:15 +0200 Subject: [PATCH] Fix PushToHubMixin when pusing to a PR revision (#34090) --- src/transformers/utils/hub.py | 2 +- tests/generation/test_configuration_utils.py | 28 +++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 540c402944..54f763de33 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -802,7 +802,7 @@ class PushToHubMixin: CommitOperationAdd(path_or_fileobj=os.path.join(working_dir, file), path_in_repo=file) ) - if revision is not None: + if revision is not None and not revision.startswith("refs/pr"): try: create_branch(repo_id=repo_id, branch=revision, token=token, exist_ok=True) except HfHubHTTPError as e: diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py index 9c7f4db3c9..f4bd551bd7 100644 --- a/tests/generation/test_configuration_utils.py +++ b/tests/generation/test_configuration_utils.py @@ -20,7 +20,7 @@ import unittest import warnings from pathlib import Path -from huggingface_hub import HfFolder, delete_repo +from huggingface_hub import HfFolder, create_pull_request, create_repo, delete_repo from parameterized import parameterized from transformers import AutoConfig, GenerationConfig, WatermarkingConfig, is_torch_available @@ -764,3 +764,29 @@ class ConfigPushToHubTester(unittest.TestCase): finally: # Always (try to) delete the repo. self._try_delete_repo(repo_id=tmp_repo, token=self._token) + + def test_push_to_hub_on_pr_revision(self): + with tempfile.TemporaryDirectory() as tmp_dir: + try: + # create a repo and a PR + repo_id = f"{USER}/test-generation-config-{Path(tmp_dir).name}" + create_repo(repo_id=repo_id, token=self._token) + pr = create_pull_request(repo_id=repo_id, title="Test PR", token=self._token) + revision = f"refs/pr/{pr.num}" + + # push to PR ref + config = GenerationConfig( + do_sample=True, + temperature=0.7, + length_penalty=1.0, + ) + config.push_to_hub(repo_id, token=self._token, revision=revision) + + # load from PR ref + new_config = GenerationConfig.from_pretrained(repo_id, revision=revision) + for k, v in config.to_dict().items(): + if k != "transformers_version": + self.assertEqual(v, getattr(new_config, k)) + finally: + # Always (try to) delete the repo. + self._try_delete_repo(repo_id=repo_id, token=self._token)