From cbf8f5d32bdbfaef2d31daba0bfdf14fe2640d0b Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 9 Mar 2020 17:29:49 -0400 Subject: [PATCH] [model upload] Support for organizations --- src/transformers/commands/user.py | 24 +++++++++++++++++----- src/transformers/hf_api.py | 33 ++++++++++++++++++++----------- tests/test_hf_api.py | 11 ++++++++++- 3 files changed, 50 insertions(+), 18 deletions(-) diff --git a/src/transformers/commands/user.py b/src/transformers/commands/user.py index 47c7860114..89a14e7477 100644 --- a/src/transformers/commands/user.py +++ b/src/transformers/commands/user.py @@ -26,13 +26,16 @@ class UserCommands(BaseTransformersCLICommand): s3_parser = parser.add_parser("s3", help="{ls, rm} Commands to interact with the files you upload on S3.") s3_subparsers = s3_parser.add_subparsers(help="s3 related commands") ls_parser = s3_subparsers.add_parser("ls") + ls_parser.add_argument("--organization", type=str, help="Optional: organization namespace.") ls_parser.set_defaults(func=lambda args: ListObjsCommand(args)) rm_parser = s3_subparsers.add_parser("rm") rm_parser.add_argument("filename", type=str, help="individual object filename to delete from S3.") + rm_parser.add_argument("--organization", type=str, help="Optional: organization namespace.") rm_parser.set_defaults(func=lambda args: DeleteObjCommand(args)) # upload upload_parser = parser.add_parser("upload") upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.") + upload_parser.add_argument("--organization", type=str, help="Optional: organization namespace.") upload_parser.add_argument( "--filename", type=str, default=None, help="Optional: override individual object filename on S3." ) @@ -91,8 +94,10 @@ class WhoamiCommand(BaseUserCommand): print("Not logged in") exit() try: - user = self._api.whoami(token) + user, orgs = self._api.whoami(token) print(user) + if orgs: + print(ANSI.bold("orgs: "), ",".join(orgs)) except HTTPError as e: print(e) @@ -130,7 +135,7 @@ class ListObjsCommand(BaseUserCommand): print("Not logged in") exit(1) try: - objs = self._api.list_objs(token) + objs = self._api.list_objs(token, organization=self.args.organization) except HTTPError as e: print(e) exit(1) @@ -148,7 +153,7 @@ class DeleteObjCommand(BaseUserCommand): print("Not logged in") exit(1) try: - self._api.delete_obj(token, filename=self.args.filename) + self._api.delete_obj(token, filename=self.args.filename, organization=self.args.organization) except HTTPError as e: print(e) exit(1) @@ -195,8 +200,15 @@ class UploadCommand(BaseUserCommand): ) exit(1) + user, _ = self._api.whoami(token) + namespace = self.args.organization if self.args.organization is not None else user + for filepath, filename in files: - print("About to upload file {} to S3 under filename {}".format(ANSI.bold(filepath), ANSI.bold(filename))) + print( + "About to upload file {} to S3 under filename {} and namespace {}".format( + ANSI.bold(filepath), ANSI.bold(filename), ANSI.bold(namespace) + ) + ) choice = input("Proceed? [Y/n] ").lower() if not (choice == "" or choice == "y" or choice == "yes"): @@ -204,6 +216,8 @@ class UploadCommand(BaseUserCommand): exit() print(ANSI.bold("Uploading... This might take a while if files are large")) for filepath, filename in files: - access_url = self._api.presign_and_upload(token=token, filename=filename, filepath=filepath) + access_url = self._api.presign_and_upload( + token=token, filename=filename, filepath=filepath, organization=self.args.organization + ) print("Your file now lives at:") print(access_url) diff --git a/src/transformers/hf_api.py b/src/transformers/hf_api.py index 00395d685f..bf1ea4c727 100644 --- a/src/transformers/hf_api.py +++ b/src/transformers/hf_api.py @@ -17,7 +17,7 @@ import io import os from os.path import expanduser -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import requests from tqdm import tqdm @@ -109,7 +109,7 @@ class HfApi: d = r.json() return d["token"] - def whoami(self, token: str) -> str: + def whoami(self, token: str) -> Tuple[str, List[str]]: """ Call HF API to know "whoami" """ @@ -117,7 +117,7 @@ class HfApi: r = requests.get(path, headers={"authorization": "Bearer {}".format(token)}) r.raise_for_status() d = r.json() - return d["user"] + return d["user"], d["orgs"] def logout(self, token: str) -> None: """ @@ -127,24 +127,28 @@ class HfApi: r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}) r.raise_for_status() - def presign(self, token: str, filename: str) -> PresignedUrl: + def presign(self, token: str, filename: str, organization: Optional[str] = None) -> PresignedUrl: """ Call HF API to get a presigned url to upload `filename` to S3. """ path = "{}/api/presign".format(self.endpoint) - r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename}) + r = requests.post( + path, + headers={"authorization": "Bearer {}".format(token)}, + json={"filename": filename, "organization": organization}, + ) r.raise_for_status() d = r.json() return PresignedUrl(**d) - def presign_and_upload(self, token: str, filename: str, filepath: str) -> str: + def presign_and_upload(self, token: str, filename: str, filepath: str, organization: Optional[str] = None) -> str: """ Get a presigned url, then upload file to S3. Outputs: url: Read-only url for the stored file on S3. """ - urls = self.presign(token, filename=filename) + urls = self.presign(token, filename=filename, organization=organization) # streaming upload: # https://2.python-requests.org/en/master/user/advanced/#streaming-uploads # @@ -159,22 +163,27 @@ class HfApi: pf.close() return urls.access - def list_objs(self, token: str) -> List[S3Obj]: + def list_objs(self, token: str, organization: Optional[str] = None) -> List[S3Obj]: """ - Call HF API to list all stored files for user. + Call HF API to list all stored files for user (or one of their organizations). """ path = "{}/api/listObjs".format(self.endpoint) - r = requests.get(path, headers={"authorization": "Bearer {}".format(token)}) + params = {"organization": organization} if organization is not None else None + r = requests.get(path, params=params, headers={"authorization": "Bearer {}".format(token)}) r.raise_for_status() d = r.json() return [S3Obj(**x) for x in d] - def delete_obj(self, token: str, filename: str): + def delete_obj(self, token: str, filename: str, organization: Optional[str] = None): """ Call HF API to delete a file stored by user """ path = "{}/api/deleteObj".format(self.endpoint) - r = requests.delete(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename}) + r = requests.delete( + path, + headers={"authorization": "Bearer {}".format(token)}, + json={"filename": filename, "organization": organization}, + ) r.raise_for_status() def model_list(self) -> List[ModelInfo]: diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index e1537bbfd4..6d3d82bcbf 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -67,8 +67,17 @@ class HfApiEndpointsTest(HfApiCommonTest): cls._api.delete_obj(token=cls._token, filename=FILE_KEY) def test_whoami(self): - user = self._api.whoami(token=self._token) + user, orgs = self._api.whoami(token=self._token) self.assertEqual(user, USER) + self.assertIsInstance(orgs, list) + + def test_presign_invalid_org(self): + with self.assertRaises(HTTPError): + _ = self._api.presign(token=self._token, filename="fake_org.txt", organization="fake") + + def test_presign_valid_org(self): + urls = self._api.presign(token=self._token, filename="valid_org.txt", organization="valid_org") + self.assertIsInstance(urls, PresignedUrl) def test_presign(self): for FILE_KEY, FILE_PATH in FILES: