Trainer push to hub (#11328)
* Initial support for upload to hub * push -> upload * Fixes + examples * Fix torchhub test * Torchhub test I hate you * push_model_to_hub -> push_to_hub * Apply mixin to other pretrained models * Remove ABC inheritance * Add tests * Typo * Run tests * Install git-lfs * Change approach * Add push_to_hub to all * Staging test suite * Typo * Maybe like this? * More deps * Cache * Adapt name * Quality * MOAR tests * Put it in testing_utils * Docs + torchhub last hope * Styling * Wrong method * Typos * Update src/transformers/file_utils.py Co-authored-by: Julien Chaumond <julien@huggingface.co> * Address review comments * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Julien Chaumond <julien@huggingface.co> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -17,6 +17,12 @@
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import BertConfig
|
||||
from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test
|
||||
|
||||
|
||||
class ConfigTester(object):
|
||||
@@ -81,3 +87,54 @@ class ConfigTester(object):
|
||||
self.create_and_test_config_from_and_save_pretrained()
|
||||
self.create_and_test_config_with_num_labels()
|
||||
self.check_config_can_be_init_without_params()
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class ConfigPushToHubTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._api = HfApi(endpoint=ENDPOINT_STAGING)
|
||||
cls._token = cls._api.login(username=USER, password=PASS)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-model")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-model-org", organization="valid_org")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
def test_push_to_hub(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config.save_pretrained(tmp_dir, push_to_hub=True, repo_name="test-model", use_auth_token=self._token)
|
||||
|
||||
new_config = BertConfig.from_pretrained(f"{USER}/test-model")
|
||||
for k, v in config.__dict__.items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config.save_pretrained(
|
||||
tmp_dir,
|
||||
push_to_hub=True,
|
||||
repo_name="test-model-org",
|
||||
use_auth_token=self._token,
|
||||
organization="valid_org",
|
||||
)
|
||||
|
||||
new_config = BertConfig.from_pretrained("valid_org/test-model-org")
|
||||
for k, v in config.__dict__.items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
Reference in New Issue
Block a user