Merge pull request #3132 from huggingface/hf_api_model_list

[hf_api] Get the public list of all the models on huggingface
This commit is contained in:
Thomas Wolf
2020-03-06 13:05:52 +01:00
committed by GitHub
2 changed files with 74 additions and 3 deletions

View File

@@ -17,7 +17,7 @@
import io
import os
from os.path import expanduser
from typing import List
from typing import Dict, List, Optional
import requests
from tqdm import tqdm
@@ -27,6 +27,10 @@ ENDPOINT = "https://huggingface.co"
class S3Obj:
"""
Data structure that represents a file belonging to the current user.
"""
def __init__(self, filename: str, LastModified: str, ETag: str, Size: int, **kwargs):
self.filename = filename
self.LastModified = LastModified
@@ -41,6 +45,50 @@ class PresignedUrl:
self.type = type # mime-type to send to S3.
class S3Object:
    """
    Record describing one publicly accessible file stored on our S3.

    Args:
        key: S3 object key.
        etag: ETag value reported for the object.
        lastModified: last-modification timestamp, kept as the string received.
        size: size of the object.
        rfilename: filename relative to config.json.
        **kwargs: any additional fields; accepted and discarded.
    """

    def __init__(self, key: str, etag: str, lastModified: str, size: int, rfilename: str, **kwargs):
        # Identity of the object.
        self.key, self.etag = key, etag
        # Descriptive metadata.
        self.lastModified, self.size = lastModified, size
        # Location of the file relative to the model's config.json.
        self.rfilename = rfilename
class ModelInfo:
    """
    Info about a public model accessible from our S3.

    Args:
        modelId: id of model.
        key: S3 object key of config.json.
        author: optional author of the model.
        downloads: optional download count.
        tags: optional list of tags; a fresh empty list is used when omitted.
        siblings: optional list of dicts describing the files that constitute
            the model; each dict is expanded into an ``S3Object``.
        **kwargs: any additional fields; accepted and discarded.
    """

    def __init__(
        self,
        modelId: str,  # id of model
        key: str,  # S3 object key of config.json
        author: Optional[str] = None,
        downloads: Optional[int] = None,
        tags: Optional[List[str]] = None,
        siblings: Optional[List[Dict]] = None,  # list of files that constitute the model
        **kwargs
    ):
        self.modelId = modelId
        self.key = key
        self.author = author
        self.downloads = downloads
        # A `[]` default would be a single list shared by every instance that
        # omits `tags`; build a new list per instance instead.
        self.tags = [] if tags is None else tags
        # Treat a missing `siblings` as "no files".
        self.siblings = [S3Object(**x) for x in (siblings if siblings is not None else [])]
class HfApi:
def __init__(self, endpoint=None):
    """
    Remember which endpoint this client uses.

    Args:
        endpoint: base URL to use; when None, the module-level ENDPOINT is used.
    """
    if endpoint is None:
        endpoint = ENDPOINT
    self.endpoint = endpoint
@@ -129,6 +177,16 @@ class HfApi:
r = requests.delete(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename})
r.raise_for_status()
def model_list(self) -> List[ModelInfo]:
    """
    Fetch the public list of all the models on huggingface, including the community models.

    Issues a GET request to ``<endpoint>/api/models`` and wraps every entry of
    the decoded JSON response in a ``ModelInfo``.

    Returns:
        One ``ModelInfo`` per entry in the response.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    response = requests.get("{}/api/models".format(self.endpoint))
    # Surface HTTP errors before attempting to decode the body.
    response.raise_for_status()
    return [ModelInfo(**entry) for entry in response.json()]
class TqdmProgressFileReader:
"""