Rely on huggingface_hub for common tools (#13100)
* Remove hf_api module and use huggingface_hub * Style * Fix to test_fetcher * Quality
This commit is contained in:
@@ -103,8 +103,8 @@ Here is the code to see all available pretrained models on the hub:
|
|||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from transformers.hf_api import HfApi
|
from huggingface_hub.hf_api import HfApi
|
||||||
model_list = HfApi().model_list()
|
model_list = HfApi().list_models()
|
||||||
org = "Helsinki-NLP"
|
org = "Helsinki-NLP"
|
||||||
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
|
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
|
||||||
suffix = [x.split('/')[1] for x in model_ids]
|
suffix = [x.split('/')[1] for x in model_ids]
|
||||||
|
|||||||
@@ -14,10 +14,10 @@ import lightning_base
|
|||||||
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
|
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
|
||||||
from distillation import distill_main
|
from distillation import distill_main
|
||||||
from finetune import SummarizationModule, main
|
from finetune import SummarizationModule, main
|
||||||
|
from huggingface_hub.hf_api import HfApi
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
from run_eval import generate_summaries_or_translations
|
from run_eval import generate_summaries_or_translations
|
||||||
from transformers import AutoConfig, AutoModelForSeq2SeqLM
|
from transformers import AutoConfig, AutoModelForSeq2SeqLM
|
||||||
from transformers.hf_api import HfApi
|
|
||||||
from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_gpu, slow
|
from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_gpu, slow
|
||||||
from utils import label_smoothed_nll_loss, lmap, load_json
|
from utils import label_smoothed_nll_loss, lmap, load_json
|
||||||
|
|
||||||
@@ -130,7 +130,7 @@ class TestSummarizationDistiller(TestCasePlus):
|
|||||||
def test_hub_configs(self):
|
def test_hub_configs(self):
|
||||||
"""I put require_torch_gpu cause I only want this to run with self-scheduled."""
|
"""I put require_torch_gpu cause I only want this to run with self-scheduled."""
|
||||||
|
|
||||||
model_list = HfApi().model_list()
|
model_list = HfApi().list_models()
|
||||||
org = "sshleifer"
|
org = "sshleifer"
|
||||||
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
|
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
|
||||||
allowed_to_be_broken = ["sshleifer/blenderbot-3B", "sshleifer/blenderbot-90M"]
|
allowed_to_be_broken = ["sshleifer/blenderbot-3B", "sshleifer/blenderbot-90M"]
|
||||||
|
|||||||
@@ -15,14 +15,13 @@
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from getpass import getpass
|
from getpass import getpass
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
|
from huggingface_hub.hf_api import HfApi, HfFolder
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
from ..hf_api import HfApi, HfFolder
|
|
||||||
from . import BaseTransformersCLICommand
|
from . import BaseTransformersCLICommand
|
||||||
|
|
||||||
|
|
||||||
@@ -148,6 +147,12 @@ class BaseUserCommand:
|
|||||||
|
|
||||||
class LoginCommand(BaseUserCommand):
|
class LoginCommand(BaseUserCommand):
|
||||||
def run(self):
|
def run(self):
|
||||||
|
print(
|
||||||
|
ANSI.red(
|
||||||
|
"WARNING! `transformers-cli login` is deprecated and will be removed in v5. Please use "
|
||||||
|
"`huggingface-cli login` instead."
|
||||||
|
)
|
||||||
|
)
|
||||||
print( # docstyle-ignore
|
print( # docstyle-ignore
|
||||||
"""
|
"""
|
||||||
_| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
|
_| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
|
||||||
@@ -175,6 +180,12 @@ class LoginCommand(BaseUserCommand):
|
|||||||
|
|
||||||
class WhoamiCommand(BaseUserCommand):
|
class WhoamiCommand(BaseUserCommand):
|
||||||
def run(self):
|
def run(self):
|
||||||
|
print(
|
||||||
|
ANSI.red(
|
||||||
|
"WARNING! `transformers-cli whoami` is deprecated and will be removed in v5. Please use "
|
||||||
|
"`huggingface-cli whoami` instead."
|
||||||
|
)
|
||||||
|
)
|
||||||
token = HfFolder.get_token()
|
token = HfFolder.get_token()
|
||||||
if token is None:
|
if token is None:
|
||||||
print("Not logged in")
|
print("Not logged in")
|
||||||
@@ -192,6 +203,12 @@ class WhoamiCommand(BaseUserCommand):
|
|||||||
|
|
||||||
class LogoutCommand(BaseUserCommand):
|
class LogoutCommand(BaseUserCommand):
|
||||||
def run(self):
|
def run(self):
|
||||||
|
print(
|
||||||
|
ANSI.red(
|
||||||
|
"WARNING! `transformers-cli logout` is deprecated and will be removed in v5. Please use "
|
||||||
|
"`huggingface-cli logout` instead."
|
||||||
|
)
|
||||||
|
)
|
||||||
token = HfFolder.get_token()
|
token = HfFolder.get_token()
|
||||||
if token is None:
|
if token is None:
|
||||||
print("Not logged in")
|
print("Not logged in")
|
||||||
@@ -203,8 +220,11 @@ class LogoutCommand(BaseUserCommand):
|
|||||||
|
|
||||||
class ListObjsCommand(BaseUserCommand):
|
class ListObjsCommand(BaseUserCommand):
|
||||||
def run(self):
|
def run(self):
|
||||||
warnings.warn(
|
print(
|
||||||
"Managing repositories through transformers-cli is deprecated. Please use `huggingface-cli` instead."
|
ANSI.red(
|
||||||
|
"WARNING! Managing repositories through transformers-cli is deprecated. "
|
||||||
|
"Please use `huggingface-cli` instead."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
token = HfFolder.get_token()
|
token = HfFolder.get_token()
|
||||||
if token is None:
|
if token is None:
|
||||||
@@ -225,8 +245,11 @@ class ListObjsCommand(BaseUserCommand):
|
|||||||
|
|
||||||
class DeleteObjCommand(BaseUserCommand):
|
class DeleteObjCommand(BaseUserCommand):
|
||||||
def run(self):
|
def run(self):
|
||||||
warnings.warn(
|
print(
|
||||||
"Managing repositories through transformers-cli is deprecated. Please use `huggingface-cli` instead."
|
ANSI.red(
|
||||||
|
"WARNING! Managing repositories through transformers-cli is deprecated. "
|
||||||
|
"Please use `huggingface-cli` instead."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
token = HfFolder.get_token()
|
token = HfFolder.get_token()
|
||||||
if token is None:
|
if token is None:
|
||||||
@@ -243,8 +266,11 @@ class DeleteObjCommand(BaseUserCommand):
|
|||||||
|
|
||||||
class ListReposObjsCommand(BaseUserCommand):
|
class ListReposObjsCommand(BaseUserCommand):
|
||||||
def run(self):
|
def run(self):
|
||||||
warnings.warn(
|
print(
|
||||||
"Managing repositories through transformers-cli is deprecated. Please use `huggingface-cli` instead."
|
ANSI.red(
|
||||||
|
"WARNING! Managing repositories through transformers-cli is deprecated. "
|
||||||
|
"Please use `huggingface-cli` instead."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
token = HfFolder.get_token()
|
token = HfFolder.get_token()
|
||||||
if token is None:
|
if token is None:
|
||||||
@@ -265,8 +291,11 @@ class ListReposObjsCommand(BaseUserCommand):
|
|||||||
|
|
||||||
class RepoCreateCommand(BaseUserCommand):
|
class RepoCreateCommand(BaseUserCommand):
|
||||||
def run(self):
|
def run(self):
|
||||||
warnings.warn(
|
print(
|
||||||
"Managing repositories through transformers-cli is deprecated. Please use `huggingface-cli` instead."
|
ANSI.red(
|
||||||
|
"WARNING! Managing repositories through transformers-cli is deprecated. "
|
||||||
|
"Please use `huggingface-cli` instead."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
token = HfFolder.get_token()
|
token = HfFolder.get_token()
|
||||||
if token is None:
|
if token is None:
|
||||||
@@ -339,8 +368,11 @@ class UploadCommand(BaseUserCommand):
|
|||||||
return files
|
return files
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
warnings.warn(
|
print(
|
||||||
"Managing repositories through transformers-cli is deprecated. Please use `huggingface-cli` instead."
|
ANSI.red(
|
||||||
|
"WARNING! Managing repositories through transformers-cli is deprecated. "
|
||||||
|
"Please use `huggingface-cli` instead."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
token = HfFolder.get_token()
|
token = HfFolder.get_token()
|
||||||
if token is None:
|
if token is None:
|
||||||
|
|||||||
@@ -1,240 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
import io
|
|
||||||
import os
|
|
||||||
from os.path import expanduser
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
|
|
||||||
ENDPOINT = "https://huggingface.co"
|
|
||||||
|
|
||||||
|
|
||||||
class RepoObj:
|
|
||||||
"""
|
|
||||||
HuggingFace git-based system, data structure that represents a file belonging to the current user.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, filename: str, lastModified: str, commit: str, size: int, **kwargs):
|
|
||||||
self.filename = filename
|
|
||||||
self.lastModified = lastModified
|
|
||||||
self.commit = commit
|
|
||||||
self.size = size
|
|
||||||
|
|
||||||
|
|
||||||
class ModelSibling:
|
|
||||||
"""
|
|
||||||
Data structure that represents a public file inside a model, accessible from huggingface.co
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, rfilename: str, **kwargs):
|
|
||||||
self.rfilename = rfilename # filename relative to the model root
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
setattr(self, k, v)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelInfo:
|
|
||||||
"""
|
|
||||||
Info about a public model accessible from huggingface.co
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
modelId: Optional[str] = None, # id of model
|
|
||||||
tags: List[str] = [],
|
|
||||||
pipeline_tag: Optional[str] = None,
|
|
||||||
siblings: Optional[List[Dict]] = None, # list of files that constitute the model
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
self.modelId = modelId
|
|
||||||
self.tags = tags
|
|
||||||
self.pipeline_tag = pipeline_tag
|
|
||||||
self.siblings = [ModelSibling(**x) for x in siblings] if siblings is not None else None
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
setattr(self, k, v)
|
|
||||||
|
|
||||||
|
|
||||||
class HfApi:
|
|
||||||
def __init__(self, endpoint=None):
|
|
||||||
self.endpoint = endpoint if endpoint is not None else ENDPOINT
|
|
||||||
|
|
||||||
def login(self, username: str, password: str) -> str:
|
|
||||||
"""
|
|
||||||
Call HF API to sign in a user and get a token if credentials are valid.
|
|
||||||
|
|
||||||
Outputs: token if credentials are valid
|
|
||||||
|
|
||||||
Throws: requests.exceptions.HTTPError if credentials are invalid
|
|
||||||
"""
|
|
||||||
path = f"{self.endpoint}/api/login"
|
|
||||||
r = requests.post(path, json={"username": username, "password": password})
|
|
||||||
r.raise_for_status()
|
|
||||||
d = r.json()
|
|
||||||
return d["token"]
|
|
||||||
|
|
||||||
def whoami(self, token: str) -> Tuple[str, List[str]]:
|
|
||||||
"""
|
|
||||||
Call HF API to know "whoami"
|
|
||||||
"""
|
|
||||||
path = f"{self.endpoint}/api/whoami"
|
|
||||||
r = requests.get(path, headers={"authorization": f"Bearer {token}"})
|
|
||||||
r.raise_for_status()
|
|
||||||
d = r.json()
|
|
||||||
return d["user"], d["orgs"]
|
|
||||||
|
|
||||||
def logout(self, token: str) -> None:
|
|
||||||
"""
|
|
||||||
Call HF API to log out.
|
|
||||||
"""
|
|
||||||
path = f"{self.endpoint}/api/logout"
|
|
||||||
r = requests.post(path, headers={"authorization": f"Bearer {token}"})
|
|
||||||
r.raise_for_status()
|
|
||||||
|
|
||||||
def model_list(self) -> List[ModelInfo]:
|
|
||||||
"""
|
|
||||||
Get the public list of all the models on huggingface.co
|
|
||||||
"""
|
|
||||||
path = f"{self.endpoint}/api/models"
|
|
||||||
r = requests.get(path)
|
|
||||||
r.raise_for_status()
|
|
||||||
d = r.json()
|
|
||||||
return [ModelInfo(**x) for x in d]
|
|
||||||
|
|
||||||
def list_repos_objs(self, token: str, organization: Optional[str] = None) -> List[RepoObj]:
|
|
||||||
"""
|
|
||||||
HuggingFace git-based system, used for models.
|
|
||||||
|
|
||||||
Call HF API to list all stored files for user (or one of their organizations).
|
|
||||||
"""
|
|
||||||
path = f"{self.endpoint}/api/repos/ls"
|
|
||||||
params = {"organization": organization} if organization is not None else None
|
|
||||||
r = requests.get(path, params=params, headers={"authorization": f"Bearer {token}"})
|
|
||||||
r.raise_for_status()
|
|
||||||
d = r.json()
|
|
||||||
return [RepoObj(**x) for x in d]
|
|
||||||
|
|
||||||
def create_repo(
|
|
||||||
self,
|
|
||||||
token: str,
|
|
||||||
name: str,
|
|
||||||
organization: Optional[str] = None,
|
|
||||||
private: Optional[bool] = None,
|
|
||||||
exist_ok=False,
|
|
||||||
lfsmultipartthresh: Optional[int] = None,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
HuggingFace git-based system, used for models.
|
|
||||||
|
|
||||||
Call HF API to create a whole repo.
|
|
||||||
|
|
||||||
Params:
|
|
||||||
private: Whether the model repo should be private (requires a paid huggingface.co account)
|
|
||||||
|
|
||||||
exist_ok: Do not raise an error if repo already exists
|
|
||||||
|
|
||||||
lfsmultipartthresh: Optional: internal param for testing purposes.
|
|
||||||
"""
|
|
||||||
path = f"{self.endpoint}/api/repos/create"
|
|
||||||
json = {"name": name, "organization": organization, "private": private}
|
|
||||||
if lfsmultipartthresh is not None:
|
|
||||||
json["lfsmultipartthresh"] = lfsmultipartthresh
|
|
||||||
r = requests.post(
|
|
||||||
path,
|
|
||||||
headers={"authorization": f"Bearer {token}"},
|
|
||||||
json=json,
|
|
||||||
)
|
|
||||||
if exist_ok and r.status_code == 409:
|
|
||||||
return ""
|
|
||||||
r.raise_for_status()
|
|
||||||
d = r.json()
|
|
||||||
return d["url"]
|
|
||||||
|
|
||||||
def delete_repo(self, token: str, name: str, organization: Optional[str] = None):
|
|
||||||
"""
|
|
||||||
HuggingFace git-based system, used for models.
|
|
||||||
|
|
||||||
Call HF API to delete a whole repo.
|
|
||||||
|
|
||||||
CAUTION(this is irreversible).
|
|
||||||
"""
|
|
||||||
path = f"{self.endpoint}/api/repos/delete"
|
|
||||||
r = requests.delete(
|
|
||||||
path,
|
|
||||||
headers={"authorization": f"Bearer {token}"},
|
|
||||||
json={"name": name, "organization": organization},
|
|
||||||
)
|
|
||||||
r.raise_for_status()
|
|
||||||
|
|
||||||
|
|
||||||
class TqdmProgressFileReader:
|
|
||||||
"""
|
|
||||||
Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`) and override `f.read()` so as to display a
|
|
||||||
tqdm progress bar.
|
|
||||||
|
|
||||||
see github.com/huggingface/transformers/pull/2078#discussion_r354739608 for implementation details.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, f: io.BufferedReader):
|
|
||||||
self.f = f
|
|
||||||
self.total_size = os.fstat(f.fileno()).st_size
|
|
||||||
self.pbar = tqdm(total=self.total_size, leave=False)
|
|
||||||
self.read = f.read
|
|
||||||
f.read = self._read
|
|
||||||
|
|
||||||
def _read(self, n=-1):
|
|
||||||
self.pbar.update(n)
|
|
||||||
return self.read(n)
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
self.pbar.close()
|
|
||||||
|
|
||||||
|
|
||||||
class HfFolder:
|
|
||||||
path_token = expanduser("~/.huggingface/token")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def save_token(cls, token):
|
|
||||||
"""
|
|
||||||
Save token, creating folder as needed.
|
|
||||||
"""
|
|
||||||
os.makedirs(os.path.dirname(cls.path_token), exist_ok=True)
|
|
||||||
with open(cls.path_token, "w+") as f:
|
|
||||||
f.write(token)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_token(cls):
|
|
||||||
"""
|
|
||||||
Get token or None if not existent.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
with open(cls.path_token, "r") as f:
|
|
||||||
return f.read()
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def delete_token(cls):
|
|
||||||
"""
|
|
||||||
Delete token. Do not fail if token does not exist.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
os.remove(cls.path_token)
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
@@ -27,8 +27,8 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from huggingface_hub.hf_api import HfApi
|
||||||
from transformers import MarianConfig, MarianMTModel, MarianTokenizer
|
from transformers import MarianConfig, MarianMTModel, MarianTokenizer
|
||||||
from transformers.hf_api import HfApi
|
|
||||||
|
|
||||||
|
|
||||||
def remove_suffix(text: str, suffix: str):
|
def remove_suffix(text: str, suffix: str):
|
||||||
@@ -65,7 +65,7 @@ def find_pretrained_model(src_lang: str, tgt_lang: str) -> List[str]:
|
|||||||
"""Find models that can accept src_lang as input and return tgt_lang as output."""
|
"""Find models that can accept src_lang as input and return tgt_lang as output."""
|
||||||
prefix = "Helsinki-NLP/opus-mt-"
|
prefix = "Helsinki-NLP/opus-mt-"
|
||||||
api = HfApi()
|
api = HfApi()
|
||||||
model_list = api.model_list()
|
model_list = api.list_models()
|
||||||
model_ids = [x.modelId for x in model_list if x.modelId.startswith("Helsinki-NLP")]
|
model_ids = [x.modelId for x in model_list if x.modelId.startswith("Helsinki-NLP")]
|
||||||
src_and_targ = [
|
src_and_targ = [
|
||||||
remove_prefix(m, prefix).lower().split("-") for m in model_ids if "+" not in m
|
remove_prefix(m, prefix).lower().split("-") for m in model_ids if "+" not in m
|
||||||
|
|||||||
@@ -1,174 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import subprocess
|
|
||||||
import time
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from requests.exceptions import HTTPError
|
|
||||||
from transformers.hf_api import HfApi, HfFolder, ModelInfo, RepoObj
|
|
||||||
from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test, require_git_lfs
|
|
||||||
|
|
||||||
|
|
||||||
ENDPOINT_STAGING_BASIC_AUTH = f"https://{USER}:{PASS}@moon-staging.huggingface.co"
|
|
||||||
|
|
||||||
REPO_NAME = f"my-model-{int(time.time())}"
|
|
||||||
REPO_NAME_LARGE_FILE = f"my-model-largefiles-{int(time.time())}"
|
|
||||||
WORKING_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/working_repo")
|
|
||||||
LARGE_FILE_14MB = "https://cdn-media.huggingface.co/lfs-largefiles/progit.epub"
|
|
||||||
LARGE_FILE_18MB = "https://cdn-media.huggingface.co/lfs-largefiles/progit.pdf"
|
|
||||||
|
|
||||||
|
|
||||||
class HfApiCommonTest(unittest.TestCase):
|
|
||||||
_api = HfApi(endpoint=ENDPOINT_STAGING)
|
|
||||||
|
|
||||||
|
|
||||||
class HfApiLoginTest(HfApiCommonTest):
|
|
||||||
def test_login_invalid(self):
|
|
||||||
with self.assertRaises(HTTPError):
|
|
||||||
self._api.login(username=USER, password="fake")
|
|
||||||
|
|
||||||
def test_login_valid(self):
|
|
||||||
token = self._api.login(username=USER, password=PASS)
|
|
||||||
self.assertIsInstance(token, str)
|
|
||||||
|
|
||||||
|
|
||||||
class HfApiEndpointsTest(HfApiCommonTest):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
"""
|
|
||||||
Share this valid token in all tests below.
|
|
||||||
"""
|
|
||||||
cls._token = cls._api.login(username=USER, password=PASS)
|
|
||||||
|
|
||||||
def test_whoami(self):
|
|
||||||
user, orgs = self._api.whoami(token=self._token)
|
|
||||||
self.assertEqual(user, USER)
|
|
||||||
self.assertIsInstance(orgs, list)
|
|
||||||
|
|
||||||
def test_list_repos_objs(self):
|
|
||||||
objs = self._api.list_repos_objs(token=self._token)
|
|
||||||
self.assertIsInstance(objs, list)
|
|
||||||
if len(objs) > 0:
|
|
||||||
o = objs[-1]
|
|
||||||
self.assertIsInstance(o, RepoObj)
|
|
||||||
|
|
||||||
def test_create_and_delete_repo(self):
|
|
||||||
self._api.create_repo(token=self._token, name=REPO_NAME)
|
|
||||||
self._api.delete_repo(token=self._token, name=REPO_NAME)
|
|
||||||
|
|
||||||
|
|
||||||
class HfApiPublicTest(unittest.TestCase):
|
|
||||||
def test_staging_model_list(self):
|
|
||||||
_api = HfApi(endpoint=ENDPOINT_STAGING)
|
|
||||||
_ = _api.model_list()
|
|
||||||
|
|
||||||
def test_model_list(self):
|
|
||||||
_api = HfApi()
|
|
||||||
models = _api.model_list()
|
|
||||||
self.assertGreater(len(models), 100)
|
|
||||||
self.assertIsInstance(models[0], ModelInfo)
|
|
||||||
|
|
||||||
|
|
||||||
class HfFolderTest(unittest.TestCase):
|
|
||||||
def test_token_workflow(self):
|
|
||||||
"""
|
|
||||||
Test the whole token save/get/delete workflow,
|
|
||||||
with the desired behavior with respect to non-existent tokens.
|
|
||||||
"""
|
|
||||||
token = f"token-{int(time.time())}"
|
|
||||||
HfFolder.save_token(token)
|
|
||||||
self.assertEqual(HfFolder.get_token(), token)
|
|
||||||
HfFolder.delete_token()
|
|
||||||
HfFolder.delete_token()
|
|
||||||
# ^^ not an error, we test that the
|
|
||||||
# second call does not fail.
|
|
||||||
self.assertEqual(HfFolder.get_token(), None)
|
|
||||||
|
|
||||||
|
|
||||||
@require_git_lfs
|
|
||||||
@is_staging_test
|
|
||||||
class HfLargefilesTest(HfApiCommonTest):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
"""
|
|
||||||
Share this valid token in all tests below.
|
|
||||||
"""
|
|
||||||
cls._token = cls._api.login(username=USER, password=PASS)
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
try:
|
|
||||||
shutil.rmtree(WORKING_REPO_DIR)
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
self._api.delete_repo(token=self._token, name=REPO_NAME_LARGE_FILE)
|
|
||||||
|
|
||||||
def setup_local_clone(self, REMOTE_URL):
|
|
||||||
REMOTE_URL_AUTH = REMOTE_URL.replace(ENDPOINT_STAGING, ENDPOINT_STAGING_BASIC_AUTH)
|
|
||||||
subprocess.run(["git", "clone", REMOTE_URL_AUTH, WORKING_REPO_DIR], check=True, capture_output=True)
|
|
||||||
subprocess.run(["git", "lfs", "track", "*.pdf"], check=True, cwd=WORKING_REPO_DIR)
|
|
||||||
subprocess.run(["git", "lfs", "track", "*.epub"], check=True, cwd=WORKING_REPO_DIR)
|
|
||||||
|
|
||||||
def test_end_to_end_thresh_6M(self):
|
|
||||||
REMOTE_URL = self._api.create_repo(
|
|
||||||
token=self._token, name=REPO_NAME_LARGE_FILE, lfsmultipartthresh=6 * 10 ** 6
|
|
||||||
)
|
|
||||||
self.setup_local_clone(REMOTE_URL)
|
|
||||||
|
|
||||||
subprocess.run(["wget", LARGE_FILE_18MB], check=True, capture_output=True, cwd=WORKING_REPO_DIR)
|
|
||||||
subprocess.run(["git", "add", "*"], check=True, cwd=WORKING_REPO_DIR)
|
|
||||||
subprocess.run(["git", "commit", "-m", "commit message"], check=True, cwd=WORKING_REPO_DIR)
|
|
||||||
|
|
||||||
# This will fail as we haven't set up our custom transfer agent yet.
|
|
||||||
failed_process = subprocess.run(["git", "push"], capture_output=True, cwd=WORKING_REPO_DIR)
|
|
||||||
self.assertEqual(failed_process.returncode, 1)
|
|
||||||
self.assertIn("transformers-cli lfs-enable-largefiles", failed_process.stderr.decode())
|
|
||||||
# ^ Instructions on how to fix this are included in the error message.
|
|
||||||
|
|
||||||
subprocess.run(["transformers-cli", "lfs-enable-largefiles", WORKING_REPO_DIR], check=True)
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
subprocess.run(["git", "push"], check=True, cwd=WORKING_REPO_DIR)
|
|
||||||
print("took", time.time() - start_time)
|
|
||||||
|
|
||||||
# To be 100% sure, let's download the resolved file
|
|
||||||
pdf_url = f"{REMOTE_URL}/resolve/main/progit.pdf"
|
|
||||||
DEST_FILENAME = "uploaded.pdf"
|
|
||||||
subprocess.run(["wget", pdf_url, "-O", DEST_FILENAME], check=True, capture_output=True, cwd=WORKING_REPO_DIR)
|
|
||||||
dest_filesize = os.stat(os.path.join(WORKING_REPO_DIR, DEST_FILENAME)).st_size
|
|
||||||
self.assertEqual(dest_filesize, 18685041)
|
|
||||||
|
|
||||||
def test_end_to_end_thresh_16M(self):
|
|
||||||
# Here we'll push one multipart and one non-multipart file in the same commit, and see what happens
|
|
||||||
REMOTE_URL = self._api.create_repo(
|
|
||||||
token=self._token, name=REPO_NAME_LARGE_FILE, lfsmultipartthresh=16 * 10 ** 6
|
|
||||||
)
|
|
||||||
self.setup_local_clone(REMOTE_URL)
|
|
||||||
|
|
||||||
subprocess.run(["wget", LARGE_FILE_18MB], check=True, capture_output=True, cwd=WORKING_REPO_DIR)
|
|
||||||
subprocess.run(["wget", LARGE_FILE_14MB], check=True, capture_output=True, cwd=WORKING_REPO_DIR)
|
|
||||||
subprocess.run(["git", "add", "*"], check=True, cwd=WORKING_REPO_DIR)
|
|
||||||
subprocess.run(["git", "commit", "-m", "both files in same commit"], check=True, cwd=WORKING_REPO_DIR)
|
|
||||||
|
|
||||||
subprocess.run(["transformers-cli", "lfs-enable-largefiles", WORKING_REPO_DIR], check=True)
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
subprocess.run(["git", "push"], check=True, cwd=WORKING_REPO_DIR)
|
|
||||||
print("took", time.time() - start_time)
|
|
||||||
@@ -17,9 +17,9 @@
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from huggingface_hub.hf_api import HfApi
|
||||||
from transformers import MarianConfig, is_torch_available
|
from transformers import MarianConfig, is_torch_available
|
||||||
from transformers.file_utils import cached_property
|
from transformers.file_utils import cached_property
|
||||||
from transformers.hf_api import HfApi
|
|
||||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
@@ -296,7 +296,7 @@ class ModelManagementTests(unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_model_names(self):
|
def test_model_names(self):
|
||||||
model_list = HfApi().model_list()
|
model_list = HfApi().list_models()
|
||||||
model_ids = [x.modelId for x in model_list if x.modelId.startswith(ORG_NAME)]
|
model_ids = [x.modelId for x in model_list if x.modelId.startswith(ORG_NAME)]
|
||||||
bad_model_ids = [mid for mid in model_ids if "+" in model_ids]
|
bad_model_ids = [mid for mid in model_ids if "+" in model_ids]
|
||||||
self.assertListEqual([], bad_model_ids)
|
self.assertListEqual([], bad_model_ids)
|
||||||
|
|||||||
@@ -412,6 +412,8 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None):
|
|||||||
|
|
||||||
# Remove duplicates
|
# Remove duplicates
|
||||||
test_files_to_run = sorted(list(set(test_files_to_run)))
|
test_files_to_run = sorted(list(set(test_files_to_run)))
|
||||||
|
# Make sure we did not end up with a test file that was removed
|
||||||
|
test_files_to_run = [f for f in test_files_to_run if os.path.isfile(f) or os.path.isdir(f)]
|
||||||
if filters is not None:
|
if filters is not None:
|
||||||
for filter in filters:
|
for filter in filters:
|
||||||
test_files_to_run = [f for f in test_files_to_run if f.startswith(filter)]
|
test_files_to_run = [f for f in test_files_to_run if f.startswith(filter)]
|
||||||
|
|||||||
Reference in New Issue
Block a user