Add revision to trainer push_to_hub (#33482)

* add revision to trainer push_to_hub

* apply suggestions

* add test for revision

* apply ruff format

* reorganize imports

* change test trainer path
This commit is contained in:
teamclouday
2024-09-17 17:11:32 -04:00
committed by GitHub
parent d8500cd229
commit 6c051b4e1e
2 changed files with 24 additions and 1 deletions

View File

@@ -31,7 +31,7 @@ from typing import Dict, List
from unittest.mock import Mock, patch
import numpy as np
from huggingface_hub import HfFolder, ModelCard, delete_repo, list_repo_commits, list_repo_files
from huggingface_hub import HfFolder, ModelCard, create_branch, delete_repo, list_repo_commits, list_repo_files
from parameterized import parameterized
from requests.exceptions import HTTPError
@@ -3933,6 +3933,25 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
model_card = ModelCard.load(repo_name)
self.assertTrue("test-trainer-tags" in model_card.data.tags)
def test_push_to_hub_with_revision(self):
# Checks if `trainer.push_to_hub()` works correctly by adding revision
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(
output_dir=os.path.join(tmp_dir, "test-trainer-revision"),
push_to_hub=True,
hub_token=self._token,
)
branch = "v1.0"
create_branch(repo_id=trainer.hub_model_id, branch=branch, token=self._token, exist_ok=True)
url = trainer.push_to_hub(revision=branch)
# Extract branch from the url
re_search = re.search(r"tree/([^/]+)/", url)
self.assertIsNotNone(re_search)
branch_name = re_search.groups()[0]
self.assertEqual(branch_name, branch)
@require_torch
@require_optuna