Trainer push to hub (#11328)
* Initial support for upload to hub * push -> upload * Fixes + examples * Fix torchhub test * Torchhub test I hate you * push_model_to_hub -> push_to_hub * Apply mixin to other pretrained models * Remove ABC inheritance * Add tests * Typo * Run tests * Install git-lfs * Change approach * Add push_to_hub to all * Staging test suite * Typo * Maybe like this? * More deps * Cache * Adapt name * Quality * MOAR tests * Put it in testing_utils * Docs + torchhub last hope * Styling * Wrong method * Typos * Update src/transformers/file_utils.py Co-authored-by: Julien Chaumond <julien@huggingface.co> * Address review comments * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Julien Chaumond <julien@huggingface.co> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -16,16 +16,23 @@
|
||||
import dataclasses
|
||||
import gc
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import AutoTokenizer, IntervalStrategy, PretrainedConfig, TrainingArguments, is_torch_available
|
||||
from transformers.file_utils import WEIGHTS_NAME
|
||||
from transformers.testing_utils import (
|
||||
ENDPOINT_STAGING,
|
||||
PASS,
|
||||
USER,
|
||||
TestCasePlus,
|
||||
get_tests_dir,
|
||||
is_staging_test,
|
||||
require_datasets,
|
||||
require_optuna,
|
||||
require_ray,
|
||||
@@ -1081,6 +1088,60 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
class TrainerIntegrationWithHubTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._api = HfApi(endpoint=ENDPOINT_STAGING)
|
||||
cls._token = cls._api.login(username=USER, password=PASS)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-model")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-model-org", organization="valid_org")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
def test_push_to_hub(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
trainer = get_regression_trainer(output_dir=tmp_dir)
|
||||
trainer.save_model()
|
||||
url = trainer.push_to_hub(repo_name="test-model", use_auth_token=self._token)
|
||||
|
||||
# Extract repo_name from the url
|
||||
re_search = re.search(ENDPOINT_STAGING + r"/([^/]+/[^/]+)/", url)
|
||||
self.assertTrue(re_search is not None)
|
||||
repo_name = re_search.groups()[0]
|
||||
|
||||
self.assertEqual(repo_name, f"{USER}/test-model")
|
||||
|
||||
model = RegressionPreTrainedModel.from_pretrained(repo_name)
|
||||
self.assertEqual(model.a.item(), trainer.model.a.item())
|
||||
self.assertEqual(model.b.item(), trainer.model.b.item())
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
trainer = get_regression_trainer(output_dir=tmp_dir)
|
||||
trainer.save_model()
|
||||
url = trainer.push_to_hub(repo_name="test-model-org", organization="valid_org", use_auth_token=self._token)
|
||||
|
||||
# Extract repo_name from the url
|
||||
re_search = re.search(ENDPOINT_STAGING + r"/([^/]+/[^/]+)/", url)
|
||||
self.assertTrue(re_search is not None)
|
||||
repo_name = re_search.groups()[0]
|
||||
self.assertEqual(repo_name, "valid_org/test-model-org")
|
||||
|
||||
model = RegressionPreTrainedModel.from_pretrained("valid_org/test-model-org")
|
||||
self.assertEqual(model.a.item(), trainer.model.a.item())
|
||||
self.assertEqual(model.b.item(), trainer.model.b.item())
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_optuna
|
||||
class TrainerHyperParameterOptunaIntegrationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user