Merge pull request #3132 from huggingface/hf_api_model_list
[hf_api] Get the public list of all the models on huggingface
This commit is contained in:
@@ -17,7 +17,7 @@
|
|||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
from os.path import expanduser
|
from os.path import expanduser
|
||||||
from typing import List
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -27,6 +27,10 @@ ENDPOINT = "https://huggingface.co"
|
|||||||
|
|
||||||
|
|
||||||
class S3Obj:
|
class S3Obj:
|
||||||
|
"""
|
||||||
|
Data structure that represents a file belonging to the current user.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, filename: str, LastModified: str, ETag: str, Size: int, **kwargs):
|
def __init__(self, filename: str, LastModified: str, ETag: str, Size: int, **kwargs):
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
self.LastModified = LastModified
|
self.LastModified = LastModified
|
||||||
@@ -41,6 +45,50 @@ class PresignedUrl:
|
|||||||
self.type = type # mime-type to send to S3.
|
self.type = type # mime-type to send to S3.
|
||||||
|
|
||||||
|
|
||||||
|
class S3Object:
|
||||||
|
"""
|
||||||
|
Data structure that represents a public file accessible on our S3.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
key: str, # S3 object key
|
||||||
|
etag: str,
|
||||||
|
lastModified: str,
|
||||||
|
size: int,
|
||||||
|
rfilename: str, # filename relative to config.json
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
self.key = key
|
||||||
|
self.etag = etag
|
||||||
|
self.lastModified = lastModified
|
||||||
|
self.size = size
|
||||||
|
self.rfilename = rfilename
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInfo:
|
||||||
|
"""
|
||||||
|
Info about a public model accessible from our S3.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
modelId: str, # id of model
|
||||||
|
key: str, # S3 object key of config.json
|
||||||
|
author: Optional[str] = None,
|
||||||
|
downloads: Optional[int] = None,
|
||||||
|
tags: List[str] = [],
|
||||||
|
siblings: List[Dict] = [], # list of files that constitute the model
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
self.modelId = modelId
|
||||||
|
self.key = key
|
||||||
|
self.author = author
|
||||||
|
self.downloads = downloads
|
||||||
|
self.tags = tags
|
||||||
|
self.siblings = [S3Object(**x) for x in siblings]
|
||||||
|
|
||||||
|
|
||||||
class HfApi:
|
class HfApi:
|
||||||
def __init__(self, endpoint=None):
|
def __init__(self, endpoint=None):
|
||||||
self.endpoint = endpoint if endpoint is not None else ENDPOINT
|
self.endpoint = endpoint if endpoint is not None else ENDPOINT
|
||||||
@@ -129,6 +177,16 @@ class HfApi:
|
|||||||
r = requests.delete(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename})
|
r = requests.delete(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename})
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
|
def model_list(self) -> List[ModelInfo]:
|
||||||
|
"""
|
||||||
|
Get the public list of all the models on huggingface, including the community models
|
||||||
|
"""
|
||||||
|
path = "{}/api/models".format(self.endpoint)
|
||||||
|
r = requests.get(path)
|
||||||
|
r.raise_for_status()
|
||||||
|
d = r.json()
|
||||||
|
return [ModelInfo(**x) for x in d]
|
||||||
|
|
||||||
|
|
||||||
class TqdmProgressFileReader:
|
class TqdmProgressFileReader:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import unittest
|
|||||||
import requests
|
import requests
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
from transformers.hf_api import HfApi, HfFolder, PresignedUrl, S3Obj
|
from transformers.hf_api import HfApi, HfFolder, ModelInfo, PresignedUrl, S3Obj
|
||||||
|
|
||||||
|
|
||||||
USER = "__DUMMY_TRANSFORMERS_USER__"
|
USER = "__DUMMY_TRANSFORMERS_USER__"
|
||||||
@@ -36,10 +36,11 @@ FILES = [
|
|||||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/empty.txt"),
|
os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/empty.txt"),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
ENDPOINT_STAGING = "https://moon-staging.huggingface.co"
|
||||||
|
|
||||||
|
|
||||||
class HfApiCommonTest(unittest.TestCase):
|
class HfApiCommonTest(unittest.TestCase):
|
||||||
_api = HfApi(endpoint="https://moon-staging.huggingface.co")
|
_api = HfApi(endpoint=ENDPOINT_STAGING)
|
||||||
|
|
||||||
|
|
||||||
class HfApiLoginTest(HfApiCommonTest):
|
class HfApiLoginTest(HfApiCommonTest):
|
||||||
@@ -92,6 +93,18 @@ class HfApiEndpointsTest(HfApiCommonTest):
|
|||||||
self.assertIsInstance(o, S3Obj)
|
self.assertIsInstance(o, S3Obj)
|
||||||
|
|
||||||
|
|
||||||
|
class HfApiPublicTest(unittest.TestCase):
|
||||||
|
def test_staging_model_list(self):
|
||||||
|
_api = HfApi(endpoint=ENDPOINT_STAGING)
|
||||||
|
_ = _api.model_list()
|
||||||
|
|
||||||
|
def test_model_list(self):
|
||||||
|
_api = HfApi()
|
||||||
|
models = _api.model_list()
|
||||||
|
self.assertGreater(len(models), 100)
|
||||||
|
self.assertIsInstance(models[0], ModelInfo)
|
||||||
|
|
||||||
|
|
||||||
class HfFolderTest(unittest.TestCase):
|
class HfFolderTest(unittest.TestCase):
|
||||||
def test_token_workflow(self):
|
def test_token_workflow(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user