[hf_api] Get the public list of all the models on huggingface

This commit is contained in:
Julien Chaumond
2020-03-04 23:33:09 -05:00
parent ff9e79ba3a
commit f564f93c84
2 changed files with 74 additions and 3 deletions

View File

@@ -17,7 +17,7 @@
import io import io
import os import os
from os.path import expanduser from os.path import expanduser
from typing import List from typing import Dict, List, Optional
import requests import requests
from tqdm import tqdm from tqdm import tqdm
@@ -27,6 +27,10 @@ ENDPOINT = "https://huggingface.co"
class S3Obj: class S3Obj:
"""
Data structure that represents a file belonging to the current user.
"""
def __init__(self, filename: str, LastModified: str, ETag: str, Size: int, **kwargs): def __init__(self, filename: str, LastModified: str, ETag: str, Size: int, **kwargs):
self.filename = filename self.filename = filename
self.LastModified = LastModified self.LastModified = LastModified
@@ -41,6 +45,50 @@ class PresignedUrl:
self.type = type # mime-type to send to S3. self.type = type # mime-type to send to S3.
class S3Object:
"""
Data structure that represents a public file accessible on our S3.
"""
def __init__(
self,
key: str, # S3 object key
etag: str,
lastModified: str,
size: int,
rfilename: str, # filename relative to config.json
**kwargs
):
self.key = key
self.etag = etag
self.lastModified = lastModified
self.size = size
self.rfilename = rfilename
class ModelInfo:
"""
Info about a public model accessible from our S3.
"""
def __init__(
self,
modelId: str, # id of model
key: str, # S3 object key of config.json
author: Optional[str] = None,
downloads: Optional[int] = None,
tags: List[str] = [],
siblings: List[Dict] = [],
**kwargs
):
self.modelId = modelId
self.key = key
self.author = author
self.downloads = downloads
self.tags = tags
self.siblings = [S3Object(**x) for x in siblings]
class HfApi: class HfApi:
def __init__(self, endpoint=None): def __init__(self, endpoint=None):
self.endpoint = endpoint if endpoint is not None else ENDPOINT self.endpoint = endpoint if endpoint is not None else ENDPOINT
@@ -129,6 +177,16 @@ class HfApi:
r = requests.delete(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename}) r = requests.delete(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename})
r.raise_for_status() r.raise_for_status()
def model_list(self) -> List[ModelInfo]:
"""
Get the public list of all the models on huggingface, including the community models
"""
path = "{}/api/models".format(self.endpoint)
r = requests.get(path)
r.raise_for_status()
d = r.json()
return [ModelInfo(**x) for x in d]
class TqdmProgressFileReader: class TqdmProgressFileReader:
""" """

View File

@@ -21,7 +21,7 @@ import unittest
import requests import requests
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers.hf_api import HfApi, HfFolder, PresignedUrl, S3Obj from transformers.hf_api import HfApi, HfFolder, ModelInfo, PresignedUrl, S3Obj
USER = "__DUMMY_TRANSFORMERS_USER__" USER = "__DUMMY_TRANSFORMERS_USER__"
@@ -36,10 +36,11 @@ FILES = [
os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/empty.txt"), os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/empty.txt"),
), ),
] ]
ENDPOINT_STAGING = "https://moon-staging.huggingface.co"
class HfApiCommonTest(unittest.TestCase): class HfApiCommonTest(unittest.TestCase):
_api = HfApi(endpoint="https://moon-staging.huggingface.co") _api = HfApi(endpoint=ENDPOINT_STAGING)
class HfApiLoginTest(HfApiCommonTest): class HfApiLoginTest(HfApiCommonTest):
@@ -92,6 +93,18 @@ class HfApiEndpointsTest(HfApiCommonTest):
self.assertIsInstance(o, S3Obj) self.assertIsInstance(o, S3Obj)
class HfApiPublicTest(unittest.TestCase):
def test_staging_model_list(self):
_api = HfApi(endpoint=ENDPOINT_STAGING)
_ = _api.model_list()
def test_model_list(self):
_api = HfApi()
models = _api.model_list()
self.assertGreater(len(models), 100)
self.assertIsInstance(models[0], ModelInfo)
class HfFolderTest(unittest.TestCase): class HfFolderTest(unittest.TestCase):
def test_token_workflow(self): def test_token_workflow(self):
""" """