FIX [Trainer / tags]: Fix trainer + tags when users do not pass "tags" to trainer.push_to_hub() (#29009)
* fix trainer tags * add test
This commit is contained in:
@@ -3842,7 +3842,10 @@ class Trainer:
|
|||||||
# Add additional tags in the case the model has already some tags and users pass
|
# Add additional tags in the case the model has already some tags and users pass
|
||||||
# "tags" argument to `push_to_hub` so that trainer automatically handles internal tags
|
# "tags" argument to `push_to_hub` so that trainer automatically handles internal tags
|
||||||
# from all models since Trainer does not call `model.push_to_hub`.
|
# from all models since Trainer does not call `model.push_to_hub`.
|
||||||
if "tags" in kwargs and getattr(self.model, "model_tags", None) is not None:
|
if getattr(self.model, "model_tags", None) is not None:
|
||||||
|
if "tags" not in kwargs:
|
||||||
|
kwargs["tags"] = []
|
||||||
|
|
||||||
# If it is a string, convert it to a list
|
# If it is a string, convert it to a list
|
||||||
if isinstance(kwargs["tags"], str):
|
if isinstance(kwargs["tags"], str):
|
||||||
kwargs["tags"] = [kwargs["tags"]]
|
kwargs["tags"] = [kwargs["tags"]]
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from typing import Dict, List
|
|||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import HfFolder, delete_repo, list_repo_commits, list_repo_files
|
from huggingface_hub import HfFolder, ModelCard, delete_repo, list_repo_commits, list_repo_files
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
@@ -2564,7 +2564,13 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
for model in ["test-trainer", "test-trainer-epoch", "test-trainer-step", "test-trainer-tensorboard"]:
|
for model in [
|
||||||
|
"test-trainer",
|
||||||
|
"test-trainer-epoch",
|
||||||
|
"test-trainer-step",
|
||||||
|
"test-trainer-tensorboard",
|
||||||
|
"test-trainer-tags",
|
||||||
|
]:
|
||||||
try:
|
try:
|
||||||
delete_repo(token=cls._token, repo_id=model)
|
delete_repo(token=cls._token, repo_id=model)
|
||||||
except HTTPError:
|
except HTTPError:
|
||||||
@@ -2695,6 +2701,31 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
|
|||||||
|
|
||||||
assert found_log is True, "No tensorboard log found in repo"
|
assert found_log is True, "No tensorboard log found in repo"
|
||||||
|
|
||||||
|
def test_push_to_hub_tags(self):
|
||||||
|
# Checks if `trainer.push_to_hub()` works correctly by adding the desired
|
||||||
|
# tag without having to pass `tags` in `push_to_hub`
|
||||||
|
# see:
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
output_dir=os.path.join(tmp_dir, "test-trainer-tags"),
|
||||||
|
push_to_hub=True,
|
||||||
|
hub_token=self._token,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.model.add_model_tags(["test-trainer-tags"])
|
||||||
|
|
||||||
|
url = trainer.push_to_hub()
|
||||||
|
|
||||||
|
# Extract repo_name from the url
|
||||||
|
re_search = re.search(ENDPOINT_STAGING + r"/([^/]+/[^/]+)/", url)
|
||||||
|
self.assertTrue(re_search is not None)
|
||||||
|
repo_name = re_search.groups()[0]
|
||||||
|
|
||||||
|
self.assertEqual(repo_name, f"{USER}/test-trainer-tags")
|
||||||
|
|
||||||
|
model_card = ModelCard.load(repo_name)
|
||||||
|
self.assertTrue("test-trainer-tags" in model_card.data.tags)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_optuna
|
@require_optuna
|
||||||
|
|||||||
Reference in New Issue
Block a user