Add revision to trainer push_to_hub (#33482)
* add revision to trainer push_to_hub * apply suggestions * add test for revision * apply ruff format * reorganize imports * change test trainer path
This commit is contained in:
@@ -4461,6 +4461,7 @@ class Trainer:
|
|||||||
commit_message: Optional[str] = "End of training",
|
commit_message: Optional[str] = "End of training",
|
||||||
blocking: bool = True,
|
blocking: bool = True,
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
|
revision: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -4473,6 +4474,8 @@ class Trainer:
|
|||||||
Whether the function should return only when the `git push` has finished.
|
Whether the function should return only when the `git push` has finished.
|
||||||
token (`str`, *optional*, defaults to `None`):
|
token (`str`, *optional*, defaults to `None`):
|
||||||
Token with write permission to overwrite Trainer's original args.
|
Token with write permission to overwrite Trainer's original args.
|
||||||
|
revision (`str`, *optional*):
|
||||||
|
The git revision to commit from. Defaults to the head of the "main" branch.
|
||||||
kwargs (`Dict[str, Any]`, *optional*):
|
kwargs (`Dict[str, Any]`, *optional*):
|
||||||
Additional keyword arguments passed along to [`~Trainer.create_model_card`].
|
Additional keyword arguments passed along to [`~Trainer.create_model_card`].
|
||||||
|
|
||||||
@@ -4526,6 +4529,7 @@ class Trainer:
|
|||||||
token=token,
|
token=token,
|
||||||
run_as_future=not blocking,
|
run_as_future=not blocking,
|
||||||
ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
|
ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
|
||||||
|
revision=revision,
|
||||||
)
|
)
|
||||||
|
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from typing import Dict, List
|
|||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import HfFolder, ModelCard, delete_repo, list_repo_commits, list_repo_files
|
from huggingface_hub import HfFolder, ModelCard, create_branch, delete_repo, list_repo_commits, list_repo_files
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
@@ -3933,6 +3933,25 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
|
|||||||
model_card = ModelCard.load(repo_name)
|
model_card = ModelCard.load(repo_name)
|
||||||
self.assertTrue("test-trainer-tags" in model_card.data.tags)
|
self.assertTrue("test-trainer-tags" in model_card.data.tags)
|
||||||
|
|
||||||
|
def test_push_to_hub_with_revision(self):
|
||||||
|
# Checks if `trainer.push_to_hub()` works correctly by adding revision
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
output_dir=os.path.join(tmp_dir, "test-trainer-revision"),
|
||||||
|
push_to_hub=True,
|
||||||
|
hub_token=self._token,
|
||||||
|
)
|
||||||
|
branch = "v1.0"
|
||||||
|
create_branch(repo_id=trainer.hub_model_id, branch=branch, token=self._token, exist_ok=True)
|
||||||
|
url = trainer.push_to_hub(revision=branch)
|
||||||
|
|
||||||
|
# Extract branch from the url
|
||||||
|
re_search = re.search(r"tree/([^/]+)/", url)
|
||||||
|
self.assertIsNotNone(re_search)
|
||||||
|
|
||||||
|
branch_name = re_search.groups()[0]
|
||||||
|
self.assertEqual(branch_name, branch)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_optuna
|
@require_optuna
|
||||||
|
|||||||
Reference in New Issue
Block a user