[model upload] Support for organizations
This commit is contained in:
@@ -26,13 +26,16 @@ class UserCommands(BaseTransformersCLICommand):
|
|||||||
s3_parser = parser.add_parser("s3", help="{ls, rm} Commands to interact with the files you upload on S3.")
|
s3_parser = parser.add_parser("s3", help="{ls, rm} Commands to interact with the files you upload on S3.")
|
||||||
s3_subparsers = s3_parser.add_subparsers(help="s3 related commands")
|
s3_subparsers = s3_parser.add_subparsers(help="s3 related commands")
|
||||||
ls_parser = s3_subparsers.add_parser("ls")
|
ls_parser = s3_subparsers.add_parser("ls")
|
||||||
|
ls_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
|
||||||
ls_parser.set_defaults(func=lambda args: ListObjsCommand(args))
|
ls_parser.set_defaults(func=lambda args: ListObjsCommand(args))
|
||||||
rm_parser = s3_subparsers.add_parser("rm")
|
rm_parser = s3_subparsers.add_parser("rm")
|
||||||
rm_parser.add_argument("filename", type=str, help="individual object filename to delete from S3.")
|
rm_parser.add_argument("filename", type=str, help="individual object filename to delete from S3.")
|
||||||
|
rm_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
|
||||||
rm_parser.set_defaults(func=lambda args: DeleteObjCommand(args))
|
rm_parser.set_defaults(func=lambda args: DeleteObjCommand(args))
|
||||||
# upload
|
# upload
|
||||||
upload_parser = parser.add_parser("upload")
|
upload_parser = parser.add_parser("upload")
|
||||||
upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
|
upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
|
||||||
|
upload_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
|
||||||
upload_parser.add_argument(
|
upload_parser.add_argument(
|
||||||
"--filename", type=str, default=None, help="Optional: override individual object filename on S3."
|
"--filename", type=str, default=None, help="Optional: override individual object filename on S3."
|
||||||
)
|
)
|
||||||
@@ -91,8 +94,10 @@ class WhoamiCommand(BaseUserCommand):
|
|||||||
print("Not logged in")
|
print("Not logged in")
|
||||||
exit()
|
exit()
|
||||||
try:
|
try:
|
||||||
user = self._api.whoami(token)
|
user, orgs = self._api.whoami(token)
|
||||||
print(user)
|
print(user)
|
||||||
|
if orgs:
|
||||||
|
print(ANSI.bold("orgs: "), ",".join(orgs))
|
||||||
except HTTPError as e:
|
except HTTPError as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
@@ -130,7 +135,7 @@ class ListObjsCommand(BaseUserCommand):
|
|||||||
print("Not logged in")
|
print("Not logged in")
|
||||||
exit(1)
|
exit(1)
|
||||||
try:
|
try:
|
||||||
objs = self._api.list_objs(token)
|
objs = self._api.list_objs(token, organization=self.args.organization)
|
||||||
except HTTPError as e:
|
except HTTPError as e:
|
||||||
print(e)
|
print(e)
|
||||||
exit(1)
|
exit(1)
|
||||||
@@ -148,7 +153,7 @@ class DeleteObjCommand(BaseUserCommand):
|
|||||||
print("Not logged in")
|
print("Not logged in")
|
||||||
exit(1)
|
exit(1)
|
||||||
try:
|
try:
|
||||||
self._api.delete_obj(token, filename=self.args.filename)
|
self._api.delete_obj(token, filename=self.args.filename, organization=self.args.organization)
|
||||||
except HTTPError as e:
|
except HTTPError as e:
|
||||||
print(e)
|
print(e)
|
||||||
exit(1)
|
exit(1)
|
||||||
@@ -195,8 +200,15 @@ class UploadCommand(BaseUserCommand):
|
|||||||
)
|
)
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
|
user, _ = self._api.whoami(token)
|
||||||
|
namespace = self.args.organization if self.args.organization is not None else user
|
||||||
|
|
||||||
for filepath, filename in files:
|
for filepath, filename in files:
|
||||||
print("About to upload file {} to S3 under filename {}".format(ANSI.bold(filepath), ANSI.bold(filename)))
|
print(
|
||||||
|
"About to upload file {} to S3 under filename {} and namespace {}".format(
|
||||||
|
ANSI.bold(filepath), ANSI.bold(filename), ANSI.bold(namespace)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
choice = input("Proceed? [Y/n] ").lower()
|
choice = input("Proceed? [Y/n] ").lower()
|
||||||
if not (choice == "" or choice == "y" or choice == "yes"):
|
if not (choice == "" or choice == "y" or choice == "yes"):
|
||||||
@@ -204,6 +216,8 @@ class UploadCommand(BaseUserCommand):
|
|||||||
exit()
|
exit()
|
||||||
print(ANSI.bold("Uploading... This might take a while if files are large"))
|
print(ANSI.bold("Uploading... This might take a while if files are large"))
|
||||||
for filepath, filename in files:
|
for filepath, filename in files:
|
||||||
access_url = self._api.presign_and_upload(token=token, filename=filename, filepath=filepath)
|
access_url = self._api.presign_and_upload(
|
||||||
|
token=token, filename=filename, filepath=filepath, organization=self.args.organization
|
||||||
|
)
|
||||||
print("Your file now lives at:")
|
print("Your file now lives at:")
|
||||||
print(access_url)
|
print(access_url)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
from os.path import expanduser
|
from os.path import expanduser
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -109,7 +109,7 @@ class HfApi:
|
|||||||
d = r.json()
|
d = r.json()
|
||||||
return d["token"]
|
return d["token"]
|
||||||
|
|
||||||
def whoami(self, token: str) -> str:
|
def whoami(self, token: str) -> Tuple[str, List[str]]:
|
||||||
"""
|
"""
|
||||||
Call HF API to know "whoami"
|
Call HF API to know "whoami"
|
||||||
"""
|
"""
|
||||||
@@ -117,7 +117,7 @@ class HfApi:
|
|||||||
r = requests.get(path, headers={"authorization": "Bearer {}".format(token)})
|
r = requests.get(path, headers={"authorization": "Bearer {}".format(token)})
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
d = r.json()
|
d = r.json()
|
||||||
return d["user"]
|
return d["user"], d["orgs"]
|
||||||
|
|
||||||
def logout(self, token: str) -> None:
|
def logout(self, token: str) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -127,24 +127,28 @@ class HfApi:
|
|||||||
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)})
|
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)})
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
def presign(self, token: str, filename: str) -> PresignedUrl:
|
def presign(self, token: str, filename: str, organization: Optional[str] = None) -> PresignedUrl:
|
||||||
"""
|
"""
|
||||||
Call HF API to get a presigned url to upload `filename` to S3.
|
Call HF API to get a presigned url to upload `filename` to S3.
|
||||||
"""
|
"""
|
||||||
path = "{}/api/presign".format(self.endpoint)
|
path = "{}/api/presign".format(self.endpoint)
|
||||||
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename})
|
r = requests.post(
|
||||||
|
path,
|
||||||
|
headers={"authorization": "Bearer {}".format(token)},
|
||||||
|
json={"filename": filename, "organization": organization},
|
||||||
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
d = r.json()
|
d = r.json()
|
||||||
return PresignedUrl(**d)
|
return PresignedUrl(**d)
|
||||||
|
|
||||||
def presign_and_upload(self, token: str, filename: str, filepath: str) -> str:
|
def presign_and_upload(self, token: str, filename: str, filepath: str, organization: Optional[str] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Get a presigned url, then upload file to S3.
|
Get a presigned url, then upload file to S3.
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
url: Read-only url for the stored file on S3.
|
url: Read-only url for the stored file on S3.
|
||||||
"""
|
"""
|
||||||
urls = self.presign(token, filename=filename)
|
urls = self.presign(token, filename=filename, organization=organization)
|
||||||
# streaming upload:
|
# streaming upload:
|
||||||
# https://2.python-requests.org/en/master/user/advanced/#streaming-uploads
|
# https://2.python-requests.org/en/master/user/advanced/#streaming-uploads
|
||||||
#
|
#
|
||||||
@@ -159,22 +163,27 @@ class HfApi:
|
|||||||
pf.close()
|
pf.close()
|
||||||
return urls.access
|
return urls.access
|
||||||
|
|
||||||
def list_objs(self, token: str) -> List[S3Obj]:
|
def list_objs(self, token: str, organization: Optional[str] = None) -> List[S3Obj]:
|
||||||
"""
|
"""
|
||||||
Call HF API to list all stored files for user.
|
Call HF API to list all stored files for user (or one of their organizations).
|
||||||
"""
|
"""
|
||||||
path = "{}/api/listObjs".format(self.endpoint)
|
path = "{}/api/listObjs".format(self.endpoint)
|
||||||
r = requests.get(path, headers={"authorization": "Bearer {}".format(token)})
|
params = {"organization": organization} if organization is not None else None
|
||||||
|
r = requests.get(path, params=params, headers={"authorization": "Bearer {}".format(token)})
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
d = r.json()
|
d = r.json()
|
||||||
return [S3Obj(**x) for x in d]
|
return [S3Obj(**x) for x in d]
|
||||||
|
|
||||||
def delete_obj(self, token: str, filename: str):
|
def delete_obj(self, token: str, filename: str, organization: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
Call HF API to delete a file stored by user
|
Call HF API to delete a file stored by user
|
||||||
"""
|
"""
|
||||||
path = "{}/api/deleteObj".format(self.endpoint)
|
path = "{}/api/deleteObj".format(self.endpoint)
|
||||||
r = requests.delete(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename})
|
r = requests.delete(
|
||||||
|
path,
|
||||||
|
headers={"authorization": "Bearer {}".format(token)},
|
||||||
|
json={"filename": filename, "organization": organization},
|
||||||
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
def model_list(self) -> List[ModelInfo]:
|
def model_list(self) -> List[ModelInfo]:
|
||||||
|
|||||||
@@ -67,8 +67,17 @@ class HfApiEndpointsTest(HfApiCommonTest):
|
|||||||
cls._api.delete_obj(token=cls._token, filename=FILE_KEY)
|
cls._api.delete_obj(token=cls._token, filename=FILE_KEY)
|
||||||
|
|
||||||
def test_whoami(self):
|
def test_whoami(self):
|
||||||
user = self._api.whoami(token=self._token)
|
user, orgs = self._api.whoami(token=self._token)
|
||||||
self.assertEqual(user, USER)
|
self.assertEqual(user, USER)
|
||||||
|
self.assertIsInstance(orgs, list)
|
||||||
|
|
||||||
|
def test_presign_invalid_org(self):
|
||||||
|
with self.assertRaises(HTTPError):
|
||||||
|
_ = self._api.presign(token=self._token, filename="fake_org.txt", organization="fake")
|
||||||
|
|
||||||
|
def test_presign_valid_org(self):
|
||||||
|
urls = self._api.presign(token=self._token, filename="valid_org.txt", organization="valid_org")
|
||||||
|
self.assertIsInstance(urls, PresignedUrl)
|
||||||
|
|
||||||
def test_presign(self):
|
def test_presign(self):
|
||||||
for FILE_KEY, FILE_PATH in FILES:
|
for FILE_KEY, FILE_PATH in FILES:
|
||||||
|
|||||||
Reference in New Issue
Block a user