Update Transformers to huggingface_hub >= 0.1.0 (#14251)

* Update Transformers to huggingface_hub >= 0.1.0

* Forgot to save...

* Style

* Fix test
This commit is contained in:
Sylvain Gugger
2021-11-02 18:58:42 -04:00
committed by GitHub
parent 519a677e87
commit 558f8543ba
15 changed files with 70 additions and 172 deletions

View File

@@ -103,8 +103,8 @@ Here is the code to see all available pretrained models on the hub:
.. code-block:: python .. code-block:: python
from huggingface_hub.hf_api import HfApi from huggingface_hub import list_models
model_list = HfApi().list_models() model_list = list_models()
org = "Helsinki-NLP" org = "Helsinki-NLP"
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)] model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
suffix = [x.split('/')[1] for x in model_ids] suffix = [x.split('/')[1] for x in model_ids]

View File

@@ -14,7 +14,7 @@ import lightning_base
from convert_pl_checkpoint_to_hf import convert_pl_to_hf from convert_pl_checkpoint_to_hf import convert_pl_to_hf
from distillation import distill_main from distillation import distill_main
from finetune import SummarizationModule, main from finetune import SummarizationModule, main
from huggingface_hub.hf_api import HfApi from huggingface_hub import list_models
from parameterized import parameterized from parameterized import parameterized
from run_eval import generate_summaries_or_translations from run_eval import generate_summaries_or_translations
from transformers import AutoConfig, AutoModelForSeq2SeqLM from transformers import AutoConfig, AutoModelForSeq2SeqLM
@@ -130,7 +130,7 @@ class TestSummarizationDistiller(TestCasePlus):
def test_hub_configs(self): def test_hub_configs(self):
"""I put require_torch_gpu cause I only want this to run with self-scheduled.""" """I put require_torch_gpu cause I only want this to run with self-scheduled."""
model_list = HfApi().list_models() model_list = list_models()
org = "sshleifer" org = "sshleifer"
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)] model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
allowed_to_be_broken = ["sshleifer/blenderbot-3B", "sshleifer/blenderbot-90M"] allowed_to_be_broken = ["sshleifer/blenderbot-3B", "sshleifer/blenderbot-90M"]

View File

@@ -100,7 +100,7 @@ _deps = [
"flax>=0.3.4", "flax>=0.3.4",
"fugashi>=1.0", "fugashi>=1.0",
"GitPython<3.1.19", "GitPython<3.1.19",
"huggingface-hub>=0.0.17", "huggingface-hub>=0.1.0,<1.0",
"importlib_metadata", "importlib_metadata",
"ipadic>=1.0.0,<2.0", "ipadic>=1.0.0,<2.0",
"isort>=5.5.4", "isort>=5.5.4",

View File

@@ -12,14 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import subprocess import subprocess
import sys
from argparse import ArgumentParser from argparse import ArgumentParser
from getpass import getpass from getpass import getpass
from typing import List, Union from typing import List, Union
from huggingface_hub.hf_api import HfApi, HfFolder from huggingface_hub.hf_api import HfFolder, create_repo, list_repos_objs, login, logout, whoami
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from . import BaseTransformersCLICommand from . import BaseTransformersCLICommand
@@ -142,7 +140,6 @@ def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
class BaseUserCommand: class BaseUserCommand:
def __init__(self, args): def __init__(self, args):
self.args = args self.args = args
self._api = HfApi()
class LoginCommand(BaseUserCommand): class LoginCommand(BaseUserCommand):
@@ -166,7 +163,7 @@ class LoginCommand(BaseUserCommand):
username = input("Username: ") username = input("Username: ")
password = getpass() password = getpass()
try: try:
token = self._api.login(username, password) token = login(username, password)
except HTTPError as e: except HTTPError as e:
# probably invalid credentials, display error message. # probably invalid credentials, display error message.
print(e) print(e)
@@ -191,7 +188,7 @@ class WhoamiCommand(BaseUserCommand):
print("Not logged in") print("Not logged in")
exit() exit()
try: try:
user, orgs = self._api.whoami(token) user, orgs = whoami(token)
print(user) print(user)
if orgs: if orgs:
print(ANSI.bold("orgs: "), ",".join(orgs)) print(ANSI.bold("orgs: "), ",".join(orgs))
@@ -214,7 +211,7 @@ class LogoutCommand(BaseUserCommand):
print("Not logged in") print("Not logged in")
exit() exit()
HfFolder.delete_token() HfFolder.delete_token()
self._api.logout(token) logout(token)
print("Successfully logged out.") print("Successfully logged out.")
@@ -222,46 +219,24 @@ class ListObjsCommand(BaseUserCommand):
def run(self): def run(self):
print( print(
ANSI.red( ANSI.red(
"WARNING! Managing repositories through transformers-cli is deprecated. " "Command removed: it used to be the way to delete an object on S3."
"Please use `huggingface-cli` instead." " We now use a git-based system for storing models and other artifacts."
" Use list-repo-objs instead"
) )
) )
token = HfFolder.get_token()
if token is None:
print("Not logged in")
exit(1) exit(1)
try:
objs = self._api.list_objs(token, organization=self.args.organization)
except HTTPError as e:
print(e)
print(ANSI.red(e.response.text))
exit(1)
if len(objs) == 0:
print("No shared file yet")
exit()
rows = [[obj.filename, obj.LastModified, obj.ETag, obj.Size] for obj in objs]
print(tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))
class DeleteObjCommand(BaseUserCommand): class DeleteObjCommand(BaseUserCommand):
def run(self): def run(self):
print( print(
ANSI.red( ANSI.red(
"WARNING! Managing repositories through transformers-cli is deprecated. " "Command removed: it used to be the way to delete an object on S3."
"Please use `huggingface-cli` instead." " We now use a git-based system for storing models and other artifacts."
" Use delete-repo instead"
) )
) )
token = HfFolder.get_token()
if token is None:
print("Not logged in")
exit(1) exit(1)
try:
self._api.delete_obj(token, filename=self.args.filename, organization=self.args.organization)
except HTTPError as e:
print(e)
print(ANSI.red(e.response.text))
exit(1)
print("Done")
class ListReposObjsCommand(BaseUserCommand): class ListReposObjsCommand(BaseUserCommand):
@@ -277,7 +252,7 @@ class ListReposObjsCommand(BaseUserCommand):
print("Not logged in") print("Not logged in")
exit(1) exit(1)
try: try:
objs = self._api.list_repos_objs(token, organization=self.args.organization) objs = list_repos_objs(token, organization=self.args.organization)
except HTTPError as e: except HTTPError as e:
print(e) print(e)
print(ANSI.red(e.response.text)) print(ANSI.red(e.response.text))
@@ -320,7 +295,7 @@ class RepoCreateCommand(BaseUserCommand):
) )
print("") print("")
user, _ = self._api.whoami(token) user, _ = whoami(token)
namespace = self.args.organization if self.args.organization is not None else user namespace = self.args.organization if self.args.organization is not None else user
full_name = f"{namespace}/{self.args.name}" full_name = f"{namespace}/{self.args.name}"
print(f"You are about to create {ANSI.bold(full_name)}") print(f"You are about to create {ANSI.bold(full_name)}")
@@ -331,7 +306,7 @@ class RepoCreateCommand(BaseUserCommand):
print("Abort") print("Abort")
exit() exit()
try: try:
url = self._api.create_repo(token, name=self.args.name, organization=self.args.organization) url = create_repo(token, name=self.args.name, organization=self.args.organization)
except HTTPError as e: except HTTPError as e:
print(e) print(e)
print(ANSI.red(e.response.text)) print(ANSI.red(e.response.text))
@@ -356,73 +331,12 @@ class DeprecatedUploadCommand(BaseUserCommand):
class UploadCommand(BaseUserCommand): class UploadCommand(BaseUserCommand):
def walk_dir(self, rel_path):
"""
Recursively list all files in a folder.
"""
entries: List[os.DirEntry] = list(os.scandir(rel_path))
files = [(os.path.join(os.getcwd(), f.path), f.path) for f in entries if f.is_file()] # (filepath, filename)
for f in entries:
if f.is_dir():
files += self.walk_dir(f.path)
return files
def run(self): def run(self):
print( print(
ANSI.red( ANSI.red(
"WARNING! Managing repositories through transformers-cli is deprecated. " "Deprecated: used to be the way to upload a model to S3."
"Please use `huggingface-cli` instead." " We now use a git-based system for storing models and other artifacts."
" Use the `repo create` command instead."
) )
) )
token = HfFolder.get_token()
if token is None:
print("Not logged in")
exit(1)
local_path = os.path.abspath(self.args.path)
if os.path.isdir(local_path):
if self.args.filename is not None:
raise ValueError("Cannot specify a filename override when uploading a folder.")
rel_path = os.path.basename(local_path)
files = self.walk_dir(rel_path)
elif os.path.isfile(local_path):
filename = self.args.filename if self.args.filename is not None else os.path.basename(local_path)
files = [(local_path, filename)]
else:
raise ValueError(f"Not a valid file or directory: {local_path}")
if sys.platform == "win32":
files = [(filepath, filename.replace(os.sep, "/")) for filepath, filename in files]
if len(files) > UPLOAD_MAX_FILES:
print(
f"About to upload {ANSI.bold(len(files))} files to S3. This is probably wrong. Please filter files "
"before uploading."
)
exit(1) exit(1)
user, _ = self._api.whoami(token)
namespace = self.args.organization if self.args.organization is not None else user
for filepath, filename in files:
print(
f"About to upload file {ANSI.bold(filepath)} to S3 under filename {ANSI.bold(filename)} and namespace "
f"{ANSI.bold(namespace)}"
)
if not self.args.yes:
choice = input("Proceed? [Y/n] ").lower()
if not (choice == "" or choice == "y" or choice == "yes"):
print("Abort")
exit()
print(ANSI.bold("Uploading... This might take a while if files are large"))
for filepath, filename in files:
try:
access_url = self._api.presign_and_upload(
token=token, filename=filename, filepath=filepath, organization=self.args.organization
)
except HTTPError as e:
print(e)
print(ANSI.red(e.response.text))
exit(1)
print("Your file now lives at:")
print(access_url)

View File

@@ -18,7 +18,7 @@ deps = {
"flax": "flax>=0.3.4", "flax": "flax>=0.3.4",
"fugashi": "fugashi>=1.0", "fugashi": "fugashi>=1.0",
"GitPython": "GitPython<3.1.19", "GitPython": "GitPython<3.1.19",
"huggingface-hub": "huggingface-hub>=0.0.17", "huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
"importlib_metadata": "importlib_metadata", "importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0", "ipadic": "ipadic>=1.0.0,<2.0",
"isort": "isort>=5.5.4", "isort": "isort>=5.5.4",

View File

@@ -48,7 +48,7 @@ from tqdm.auto import tqdm
import requests import requests
from filelock import FileLock from filelock import FileLock
from huggingface_hub import HfApi, HfFolder, Repository from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami
from transformers.utils.versions import importlib_metadata from transformers.utils.versions import importlib_metadata
from . import __version__ from . import __version__
@@ -1808,17 +1808,14 @@ def get_list_of_files(
if is_offline_mode() or local_files_only: if is_offline_mode() or local_files_only:
return [] return []
# Otherwise we grab the token and use the model_info method. # Otherwise we grab the token and use the list_repo_files method.
if isinstance(use_auth_token, str): if isinstance(use_auth_token, str):
token = use_auth_token token = use_auth_token
elif use_auth_token is True: elif use_auth_token is True:
token = HfFolder.get_token() token = HfFolder.get_token()
else: else:
token = None token = None
model_info = HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).model_info( return list_repo_files(path_or_repo, revision=revision, token=token)
path_or_repo, revision=revision, token=token
)
return [f.rfilename for f in model_info.siblings]
class cached_property(property): class cached_property(property):
@@ -2308,7 +2305,7 @@ class PushToHubMixin:
token = None token = None
# Special provision for the test endpoint (CI) # Special provision for the test endpoint (CI)
return HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).create_repo( return create_repo(
token, token,
repo_name, repo_name,
organization=organization, organization=organization,
@@ -2366,7 +2363,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
if token is None: if token is None:
token = HfFolder.get_token() token = HfFolder.get_token()
if organization is None: if organization is None:
username = HfApi().whoami(token)["name"] username = whoami(token)["name"]
return f"{username}/{model_id}" return f"{username}/{model_id}"
else: else:
return f"{organization}/{model_id}" return f"{organization}/{model_id}"

View File

@@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, Union
import requests import requests
import yaml import yaml
from huggingface_hub import HfApi from huggingface_hub import model_info
from . import __version__ from . import __version__
from .file_utils import ( from .file_utils import (
@@ -387,8 +387,8 @@ class TrainingSummary:
and len(self.finetuned_from) > 0 and len(self.finetuned_from) > 0
): ):
try: try:
model_info = HfApi().model_info(self.finetuned_from) info = model_info(self.finetuned_from)
for tag in model_info.tags: for tag in info.tags:
if tag.startswith("license:"): if tag.startswith("license:"):
self.license = tag[8:] self.license = tag[8:]
except requests.exceptions.HTTPError: except requests.exceptions.HTTPError:

View File

@@ -27,7 +27,7 @@ import torch
from torch import nn from torch import nn
from tqdm import tqdm from tqdm import tqdm
from huggingface_hub.hf_api import HfApi from huggingface_hub.hf_api import list_models
from transformers import MarianConfig, MarianMTModel, MarianTokenizer from transformers import MarianConfig, MarianMTModel, MarianTokenizer
@@ -64,8 +64,7 @@ def load_layers_(layer_lst: nn.ModuleList, opus_state: dict, converter, is_decod
def find_pretrained_model(src_lang: str, tgt_lang: str) -> List[str]: def find_pretrained_model(src_lang: str, tgt_lang: str) -> List[str]:
"""Find models that can accept src_lang as input and return tgt_lang as output.""" """Find models that can accept src_lang as input and return tgt_lang as output."""
prefix = "Helsinki-NLP/opus-mt-" prefix = "Helsinki-NLP/opus-mt-"
api = HfApi() model_list = list_models()
model_list = api.list_models()
model_ids = [x.modelId for x in model_list if x.modelId.startswith("Helsinki-NLP")] model_ids = [x.modelId for x in model_list if x.modelId.startswith("Helsinki-NLP")]
src_and_targ = [ src_and_targ = [
remove_prefix(m, prefix).lower().split("-") for m in model_ids if "+" not in m remove_prefix(m, prefix).lower().split("-") for m in model_ids if "+" not in m

View File

@@ -19,11 +19,11 @@ import os
import tempfile import tempfile
import unittest import unittest
from huggingface_hub import HfApi from huggingface_hub import delete_repo, login
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers import BertConfig, GPT2Config, is_torch_available from transformers import BertConfig, GPT2Config, is_torch_available
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test from transformers.testing_utils import PASS, USER, is_staging_test
config_common_kwargs = { config_common_kwargs = {
@@ -194,18 +194,17 @@ class ConfigTester(object):
class ConfigPushToHubTester(unittest.TestCase): class ConfigPushToHubTester(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls._api = HfApi(endpoint=ENDPOINT_STAGING) cls._token = login(username=USER, password=PASS)
cls._token = cls._api.login(username=USER, password=PASS)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
try: try:
cls._api.delete_repo(token=cls._token, name="test-config") delete_repo(token=cls._token, name="test-config")
except HTTPError: except HTTPError:
pass pass
try: try:
cls._api.delete_repo(token=cls._token, name="test-config-org", organization="valid_org") delete_repo(token=cls._token, name="test-config-org", organization="valid_org")
except HTTPError: except HTTPError:
pass pass

View File

@@ -28,13 +28,12 @@ from typing import Dict, List, Tuple
import numpy as np import numpy as np
import transformers import transformers
from huggingface_hub import HfApi, Repository from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging
from transformers.file_utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available from transformers.file_utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available
from transformers.models.auto import get_values from transformers.models.auto import get_values
from transformers.testing_utils import ( from transformers.testing_utils import (
ENDPOINT_STAGING,
PASS, PASS,
USER, USER,
CaptureLogger, CaptureLogger,
@@ -2122,23 +2121,22 @@ class FakeModel(PreTrainedModel):
class ModelPushToHubTester(unittest.TestCase): class ModelPushToHubTester(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls._api = HfApi(endpoint=ENDPOINT_STAGING) cls._token = login(username=USER, password=PASS)
cls._token = cls._api.login(username=USER, password=PASS)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
try: try:
cls._api.delete_repo(token=cls._token, name="test-model") delete_repo(token=cls._token, name="test-model")
except HTTPError: except HTTPError:
pass pass
try: try:
cls._api.delete_repo(token=cls._token, name="test-model-org", organization="valid_org") delete_repo(token=cls._token, name="test-model-org", organization="valid_org")
except HTTPError: except HTTPError:
pass pass
try: try:
cls._api.delete_repo(token=cls._token, name="test-dynamic-model") delete_repo(token=cls._token, name="test-dynamic-model")
except HTTPError: except HTTPError:
pass pass

View File

@@ -22,19 +22,11 @@ from typing import List, Tuple
import numpy as np import numpy as np
import transformers import transformers
from huggingface_hub import HfApi from huggingface_hub import delete_repo, login
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers import BertConfig, is_flax_available, is_torch_available from transformers import BertConfig, is_flax_available, is_torch_available
from transformers.models.auto import get_values from transformers.models.auto import get_values
from transformers.testing_utils import ( from transformers.testing_utils import PASS, USER, CaptureLogger, is_pt_flax_cross_test, is_staging_test, require_flax
ENDPOINT_STAGING,
PASS,
USER,
CaptureLogger,
is_pt_flax_cross_test,
is_staging_test,
require_flax,
)
from transformers.utils import logging from transformers.utils import logging
@@ -627,18 +619,17 @@ class FlaxModelTesterMixin:
class FlaxModelPushToHubTester(unittest.TestCase): class FlaxModelPushToHubTester(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls._api = HfApi(endpoint=ENDPOINT_STAGING) cls._token = login(username=USER, password=PASS)
cls._token = cls._api.login(username=USER, password=PASS)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
try: try:
cls._api.delete_repo(token=cls._token, name="test-model-flax") delete_repo(token=cls._token, name="test-model-flax")
except HTTPError: except HTTPError:
pass pass
try: try:
cls._api.delete_repo(token=cls._token, name="test-model-flax-org", organization="valid_org") delete_repo(token=cls._token, name="test-model-flax-org", organization="valid_org")
except HTTPError: except HTTPError:
pass pass

View File

@@ -17,7 +17,7 @@
import tempfile import tempfile
import unittest import unittest
from huggingface_hub.hf_api import HfApi from huggingface_hub.hf_api import list_models
from transformers import MarianConfig, is_torch_available from transformers import MarianConfig, is_torch_available
from transformers.file_utils import cached_property from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
@@ -296,7 +296,7 @@ class ModelManagementTests(unittest.TestCase):
@slow @slow
@require_torch @require_torch
def test_model_names(self): def test_model_names(self):
model_list = HfApi().list_models() model_list = list_models()
model_ids = [x.modelId for x in model_list if x.modelId.startswith(ORG_NAME)] model_ids = [x.modelId for x in model_list if x.modelId.startswith(ORG_NAME)]
bad_model_ids = [mid for mid in model_ids if "+" in model_ids] bad_model_ids = [mid for mid in model_ids if "+" in model_ids]
self.assertListEqual([], bad_model_ids) self.assertListEqual([], bad_model_ids)

View File

@@ -24,12 +24,11 @@ import unittest
from importlib import import_module from importlib import import_module
from typing import List, Tuple from typing import List, Tuple
from huggingface_hub import HfApi from huggingface_hub import delete_repo, login
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers import is_tf_available from transformers import is_tf_available
from transformers.models.auto import get_values from transformers.models.auto import get_values
from transformers.testing_utils import ( from transformers.testing_utils import (
ENDPOINT_STAGING,
PASS, PASS,
USER, USER,
CaptureLogger, CaptureLogger,
@@ -1530,18 +1529,17 @@ class UtilsFunctionsTest(unittest.TestCase):
class TFModelPushToHubTester(unittest.TestCase): class TFModelPushToHubTester(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls._api = HfApi(endpoint=ENDPOINT_STAGING) cls._token = login(username=USER, password=PASS)
cls._token = cls._api.login(username=USER, password=PASS)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
try: try:
cls._api.delete_repo(token=cls._token, name="test-model-tf") delete_repo(token=cls._token, name="test-model-tf")
except HTTPError: except HTTPError:
pass pass
try: try:
cls._api.delete_repo(token=cls._token, name="test-model-tf-org", organization="valid_org") delete_repo(token=cls._token, name="test-model-tf-org", organization="valid_org")
except HTTPError: except HTTPError:
pass pass

View File

@@ -27,7 +27,7 @@ from collections import OrderedDict
from itertools import takewhile from itertools import takewhile
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
from huggingface_hub import HfApi from huggingface_hub import delete_repo, login
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers import ( from transformers import (
AlbertTokenizer, AlbertTokenizer,
@@ -44,7 +44,6 @@ from transformers import (
is_torch_available, is_torch_available,
) )
from transformers.testing_utils import ( from transformers.testing_utils import (
ENDPOINT_STAGING,
PASS, PASS,
USER, USER,
get_tests_dir, get_tests_dir,
@@ -3520,18 +3519,17 @@ class TokenizerPushToHubTester(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls._api = HfApi(endpoint=ENDPOINT_STAGING) cls._token = login(username=USER, password=PASS)
cls._token = cls._api.login(username=USER, password=PASS)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
try: try:
cls._api.delete_repo(token=cls._token, name="test-tokenizer") delete_repo(token=cls._token, name="test-tokenizer")
except HTTPError: except HTTPError:
pass pass
try: try:
cls._api.delete_repo(token=cls._token, name="test-tokenizer-org", organization="valid_org") delete_repo(token=cls._token, name="test-tokenizer-org", organization="valid_org")
except HTTPError: except HTTPError:
pass pass

View File

@@ -26,7 +26,7 @@ from pathlib import Path
import numpy as np import numpy as np
from huggingface_hub import HfApi, Repository from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers import ( from transformers import (
AutoTokenizer, AutoTokenizer,
@@ -1307,19 +1307,18 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
class TrainerIntegrationWithHubTester(unittest.TestCase): class TrainerIntegrationWithHubTester(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls._api = HfApi(endpoint=ENDPOINT_STAGING) cls._token = login(username=USER, password=PASS)
cls._token = cls._api.login(username=USER, password=PASS)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
for model in ["test-trainer", "test-trainer-epoch", "test-trainer-step"]: for model in ["test-trainer", "test-trainer-epoch", "test-trainer-step"]:
try: try:
cls._api.delete_repo(token=cls._token, name=model) delete_repo(token=cls._token, name=model)
except HTTPError: except HTTPError:
pass pass
try: try:
cls._api.delete_repo(token=cls._token, name="test-trainer-org", organization="valid_org") delete_repo(token=cls._token, name="test-trainer-org", organization="valid_org")
except HTTPError: except HTTPError:
pass pass
@@ -1396,6 +1395,10 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
print(commits, len(commits)) print(commits, len(commits))
def test_push_to_hub_with_saves_each_n_steps(self): def test_push_to_hub_with_saves_each_n_steps(self):
num_gpus = max(1, get_gpu_count())
if num_gpus > 2:
return
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer( trainer = get_regression_trainer(
output_dir=os.path.join(tmp_dir, "test-trainer-step"), output_dir=os.path.join(tmp_dir, "test-trainer-step"),
@@ -1409,7 +1412,8 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
_ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-step", use_auth_token=self._token) _ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-step", use_auth_token=self._token)
commits = self.get_commit_history(tmp_dir) commits = self.get_commit_history(tmp_dir)
expected_commits = [f"Training in progress, step {i}" for i in range(20, 0, -5)] total_steps = 20 // num_gpus
expected_commits = [f"Training in progress, step {i}" for i in range(total_steps, 0, -5)]
expected_commits.append("initial commit") expected_commits.append("initial commit")
self.assertListEqual(commits, expected_commits) self.assertListEqual(commits, expected_commits)
print(commits, len(commits)) print(commits, len(commits))