Merge pull request #3132 from huggingface/hf_api_model_list
[hf_api] Get the public list of all the models on huggingface
This commit is contained in:
@@ -17,7 +17,7 @@
|
||||
import io
|
||||
import os
|
||||
from os.path import expanduser
|
||||
from typing import List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
@@ -27,6 +27,10 @@ ENDPOINT = "https://huggingface.co"
|
||||
|
||||
|
||||
class S3Obj:
|
||||
"""
|
||||
Data structure that represents a file belonging to the current user.
|
||||
"""
|
||||
|
||||
def __init__(self, filename: str, LastModified: str, ETag: str, Size: int, **kwargs):
|
||||
self.filename = filename
|
||||
self.LastModified = LastModified
|
||||
@@ -41,6 +45,50 @@ class PresignedUrl:
|
||||
self.type = type # mime-type to send to S3.
|
||||
|
||||
|
||||
class S3Object:
|
||||
"""
|
||||
Data structure that represents a public file accessible on our S3.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: str, # S3 object key
|
||||
etag: str,
|
||||
lastModified: str,
|
||||
size: int,
|
||||
rfilename: str, # filename relative to config.json
|
||||
**kwargs
|
||||
):
|
||||
self.key = key
|
||||
self.etag = etag
|
||||
self.lastModified = lastModified
|
||||
self.size = size
|
||||
self.rfilename = rfilename
|
||||
|
||||
|
||||
class ModelInfo:
|
||||
"""
|
||||
Info about a public model accessible from our S3.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
modelId: str, # id of model
|
||||
key: str, # S3 object key of config.json
|
||||
author: Optional[str] = None,
|
||||
downloads: Optional[int] = None,
|
||||
tags: List[str] = [],
|
||||
siblings: List[Dict] = [], # list of files that constitute the model
|
||||
**kwargs
|
||||
):
|
||||
self.modelId = modelId
|
||||
self.key = key
|
||||
self.author = author
|
||||
self.downloads = downloads
|
||||
self.tags = tags
|
||||
self.siblings = [S3Object(**x) for x in siblings]
|
||||
|
||||
|
||||
class HfApi:
|
||||
def __init__(self, endpoint=None):
|
||||
self.endpoint = endpoint if endpoint is not None else ENDPOINT
|
||||
@@ -129,6 +177,16 @@ class HfApi:
|
||||
r = requests.delete(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename})
|
||||
r.raise_for_status()
|
||||
|
||||
def model_list(self) -> List[ModelInfo]:
|
||||
"""
|
||||
Get the public list of all the models on huggingface, including the community models
|
||||
"""
|
||||
path = "{}/api/models".format(self.endpoint)
|
||||
r = requests.get(path)
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
return [ModelInfo(**x) for x in d]
|
||||
|
||||
|
||||
class TqdmProgressFileReader:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user