Support for private models from huggingface.co (#9141)

* minor wording tweaks

* Create private model repo + exist_ok flag

* file_utils: `use_auth_token`

* Update src/transformers/file_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Propagate doc from @sgugger

Co-Authored-By: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Julien Chaumond
2020-12-16 16:09:57 +01:00
committed by GitHub
parent c69d19faa8
commit fb650df859
8 changed files with 77 additions and 9 deletions

View File

@@ -206,7 +206,7 @@ class HfApi:
def model_list(self) -> List[ModelInfo]:
"""
Get the public list of all the models on huggingface, including the community models
Get the public list of all the models on huggingface.co
"""
path = "{}/api/models".format(self.endpoint)
r = requests.get(path)
@@ -228,7 +228,13 @@ class HfApi:
return [RepoObj(**x) for x in d]
def create_repo(
self, token: str, name: str, organization: Optional[str] = None, lfsmultipartthresh: Optional[int] = None
self,
token: str,
name: str,
organization: Optional[str] = None,
private: Optional[bool] = None,
exist_ok=False,
lfsmultipartthresh: Optional[int] = None,
) -> str:
"""
HuggingFace git-based system, used for models.
@@ -236,10 +242,14 @@ class HfApi:
Call HF API to create a whole repo.
Params:
private: Whether the model repo should be private (requires a paid huggingface.co account)
exist_ok: Do not raise an error if repo already exists
lfsmultipartthresh: Optional: internal param for testing purposes.
"""
path = "{}/api/repos/create".format(self.endpoint)
json = {"name": name, "organization": organization}
json = {"name": name, "organization": organization, "private": private}
if lfsmultipartthresh is not None:
json["lfsmultipartthresh"] = lfsmultipartthresh
r = requests.post(
@@ -247,6 +257,8 @@ class HfApi:
headers={"authorization": "Bearer {}".format(token)},
json=json,
)
if exist_ok and r.status_code == 409:
return ""
r.raise_for_status()
d = r.json()
return d["url"]