Merge pull request #2044 from huggingface/cli_upload
CLI for authenticated file sharing
This commit is contained in:
10
setup.py
10
setup.py
@@ -36,6 +36,12 @@ To create the package for pypi.
|
||||
from io import open
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
# Optional dependency groups; this mapping is handed to setup() as
# `extras_require`.
extras = {
    'serving': ['uvicorn', 'fastapi']
}
# 'all' is the union of every group. Each value in `extras` is itself a list
# of requirement strings, so the groups are flattened into one flat list
# rather than being collected as a list of lists.
extras['all'] = [package for group in extras.values() for package in group]
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="2.2.1",
|
||||
@@ -61,6 +67,10 @@ setup(
|
||||
"transformers=transformers.__main__:main",
|
||||
]
|
||||
},
|
||||
extras_require=extras,
|
||||
scripts=[
|
||||
'transformers-cli'
|
||||
],
|
||||
# python_requires='>=3.5.0',
|
||||
tests_require=['pytest'],
|
||||
classifiers=[
|
||||
|
||||
23
transformers-cli
Normal file
23
transformers-cli
Normal file
@@ -0,0 +1,23 @@
|
||||
#!/usr/bin/env python
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from transformers.commands.user import UserCommands
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Imported locally so the script does not depend on the `exit` helper that
    # the `site` module injects for interactive sessions.
    import sys

    parser = ArgumentParser(description='Transformers CLI tool', usage='transformers-cli <command> [<args>]')
    commands_parser = parser.add_subparsers(help='transformers-cli command helpers')

    # Register commands
    UserCommands.register_subcommand(commands_parser)

    # Let's go
    args = parser.parse_args()

    # Every sub-parser stores a command factory under `func`. When the
    # attribute is missing no sub-command was selected, so show the help text
    # and report failure to the calling shell.
    if not hasattr(args, 'func'):
        parser.print_help()
        sys.exit(1)

    # Run: build the command object from the parsed arguments and execute it.
    service = args.func(args)
    service.run()
|
||||
12
transformers/commands/__init__.py
Normal file
12
transformers/commands/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from argparse import ArgumentParser
|
||||
|
||||
class BaseTransformersCLICommand(ABC):
    """
    Abstract interface for a `transformers-cli` command.

    Both methods are abstract, so a concrete subclass has to override them
    before it can be instantiated.
    """

    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Declare the command's sub-parser(s) and arguments on `parser`.
        """
        raise NotImplementedError()

    @abstractmethod
    def run(self):
        """
        Execute the command.
        """
        raise NotImplementedError()
|
||||
165
transformers/commands/user.py
Normal file
165
transformers/commands/user.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from argparse import ArgumentParser
|
||||
from getpass import getpass
|
||||
import os
|
||||
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
from transformers.hf_api import HfApi, HfFolder, HTTPError
|
||||
|
||||
|
||||
class UserCommands(BaseTransformersCLICommand):
    """
    Registers the user-facing sub-commands: login, whoami, logout, ls, upload.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Add one sub-parser per command to `parser`.

        Each sub-parser stores, under the `func` default, a factory that
        builds the matching command object from the parsed arguments.
        """
        # Commands that take no arguments, in registration order.
        argument_free_commands = (
            ('login', lambda args: LoginCommand(args)),
            ('whoami', lambda args: WhoamiCommand(args)),
            ('logout', lambda args: LogoutCommand(args)),
            ('ls', lambda args: ListObjsCommand(args)),
        )
        for command_name, factory in argument_free_commands:
            parser.add_parser(command_name).set_defaults(func=factory)

        # upload
        upload = parser.add_parser('upload')
        upload.add_argument('file', type=str, help='Local filepath of the file to upload.')
        upload.add_argument('--filename', type=str, default=None, help='Optional: override object filename on S3.')
        upload.set_defaults(func=lambda args: UploadCommand(args))
|
||||
|
||||
|
||||
|
||||
class ANSI:
    """
    Helper for en.wikipedia.org/wiki/ANSI_escape_code
    """
    # Escape sequence that switches bold rendering on.
    _bold = u"\u001b[1m"
    # Escape sequence that resets all text attributes.
    _reset = u"\u001b[0m"

    @classmethod
    def bold(cls, s):
        """
        Return `s` wrapped between the bold and reset escape sequences.
        """
        return "{start}{text}{stop}".format(start=cls._bold, text=s, stop=cls._reset)
|
||||
|
||||
|
||||
class BaseUserCommand:
    """
    Shared base of the user commands: keeps the parsed CLI arguments and an
    API client.
    """

    def __init__(self, args):
        """
        Store `args` on the instance and create a default-configured `HfApi`.
        """
        self._api = HfApi()
        self.args = args
|
||||
|
||||
|
||||
class LoginCommand(BaseUserCommand):
    """
    `login` command: ask for credentials, obtain a token and save it locally.
    """

    def run(self):
        """
        Prompt for username and password, log in through the API and store
        the returned token with `HfFolder.save_token`.

        Exits with status 1 when the API call raises an `HTTPError`.
        """
        # Local import: use sys.exit rather than the `exit` helper that the
        # `site` module only provides for interactive sessions.
        import sys

        # NOTE(review): the banner is reproduced as it appears in this copy of
        # the file; its inner spacing looks collapsed -- confirm against the
        # canonical source before merging.
        print("""
_| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
_| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
_| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|

""")
        username = input("Username: ")
        # getpass() reads the password without echoing it.
        password = getpass()
        try:
            token = self._api.login(username, password)
        except HTTPError as e:
            # probably invalid credentials, display error message.
            print(e)
            sys.exit(1)
        HfFolder.save_token(token)
        print("Login successful")
        print("Your token:", token, "\n")
        print("Your token has been saved to", HfFolder.path_token)
|
||||
|
||||
|
||||
class WhoamiCommand(BaseUserCommand):
    """
    `whoami` command: print the user the stored token belongs to.
    """

    def run(self):
        """
        Look up the saved token and print the user returned by the API.

        Exits with status 1 when no token is stored or when the API call
        raises an `HTTPError`, matching the `ls` and `upload` commands.
        """
        # Local import: use sys.exit rather than the `exit` helper that the
        # `site` module only provides for interactive sessions.
        import sys

        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            sys.exit(1)
        try:
            user = self._api.whoami(token)
        except HTTPError as e:
            print(e)
            sys.exit(1)
        print(user)
|
||||
|
||||
|
||||
class LogoutCommand(BaseUserCommand):
    """
    `logout` command: remove the local token and log out through the API.
    """

    def run(self):
        """
        Delete the saved token, then ask the API to log the token out.

        Exits with status 0 when there is no saved token, and with status 1
        when the API call raises an `HTTPError`.
        """
        # Local import: use sys.exit rather than the `exit` helper that the
        # `site` module only provides for interactive sessions.
        import sys

        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            sys.exit()
        # The local copy is removed first, so it is gone even if the remote
        # call below fails.
        HfFolder.delete_token()
        try:
            self._api.logout(token)
        except HTTPError as e:
            # Report the failure the same way the other commands do instead of
            # ending with a traceback.
            print(e)
            sys.exit(1)
        print("Successfully logged out.")
|
||||
|
||||
|
||||
class ListObjsCommand(BaseUserCommand):
    """
    `ls` command: print a table of the files returned by `HfApi.list_objs`.
    """

    def tabulate(self, rows, headers):
        # type: (List[List[Union[str, int]]], List[str]) -> str
        """
        Render `rows` under `headers` as a left-aligned text table.

        Each column is as wide as its longest cell (header included), and a
        dashed line separates the header from the data rows.

        Inspired by:
        stackoverflow.com/a/8356620/593036
        stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
        """
        # Transpose rows + header so that each tuple holds one column's cells.
        col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
        # Builds e.g. "{:8} {:12} " -- one width-padded field per column.
        row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
        lines = [
            row_format.format(*headers),
            row_format.format(*["-" * w for w in col_widths]),
        ]
        lines.extend(row_format.format(*row) for row in rows)
        return "\n".join(lines)

    def run(self):
        """
        Fetch the file list for the saved token and print it as a table.

        Exits with status 1 when no token is stored or the API call raises an
        `HTTPError`, and with status 0 when the list is empty.
        """
        # Local import: use sys.exit rather than the `exit` helper that the
        # `site` module only provides for interactive sessions.
        import sys

        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            sys.exit(1)
        try:
            objs = self._api.list_objs(token)
        except HTTPError as e:
            print(e)
            sys.exit(1)
        if not objs:
            print("No shared file yet")
            sys.exit()
        rows = [
            [obj.filename, obj.LastModified, obj.ETag, obj.Size]
            for obj in objs
        ]
        print(
            self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"])
        )
|
||||
|
||||
|
||||
class UploadCommand(BaseUserCommand):
    """
    `upload` command: upload one local file after interactive confirmation.
    """

    def run(self):
        """
        Resolve the file path, ask for confirmation, upload the file through
        `HfApi.presign_and_upload` and print the resulting access url.

        Exits with status 1 when no token is stored, when the local file does
        not exist, or when the API raises an `HTTPError`; exits with status 0
        when the user declines the confirmation prompt.
        """
        # Local import: use sys.exit rather than the `exit` helper that the
        # `site` module only provides for interactive sessions.
        import sys

        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            sys.exit(1)
        # os.path.join discards the cwd prefix when `file` is already absolute.
        filepath = os.path.join(os.getcwd(), self.args.file)
        filename = self.args.filename if self.args.filename is not None else os.path.basename(filepath)
        # Fail before prompting or contacting the API if there is nothing to
        # upload.
        if not os.path.isfile(filepath):
            print("File not found: {}".format(filepath))
            sys.exit(1)
        print(
            "About to upload file {} to S3 under filename {}".format(
                ANSI.bold(filepath), ANSI.bold(filename)
            )
        )

        # An empty answer counts as "yes" (the default shown in the prompt).
        choice = input("Proceed? [Y/n] ").lower()
        if choice not in ("", "y", "yes"):
            print("Abort")
            sys.exit()
        print(
            ANSI.bold("Uploading... This might take a while if file is large")
        )
        try:
            access_url = self._api.presign_and_upload(
                token=token, filename=filename, filepath=filepath
            )
        except HTTPError as e:
            # Report the failure the same way the other commands do instead of
            # ending with a traceback.
            print(e)
            sys.exit(1)
        print("Your file now lives at:")
        print(access_url)
|
||||
196
transformers/hf_api.py
Normal file
196
transformers/hf_api.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
from os.path import expanduser
|
||||
import six
|
||||
|
||||
import requests
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
ENDPOINT = "https://huggingface.co"
|
||||
|
||||
class S3Obj:
    """
    Plain record describing one stored file.

    Any additional keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        filename,      # type: str
        LastModified,  # type: str
        ETag,          # type: str
        Size,          # type: int
        **kwargs
    ):
        self.filename, self.LastModified = filename, LastModified
        self.ETag, self.Size = ETag, Size
|
||||
|
||||
|
||||
class PresignedUrl:
    """
    Plain record holding the urls and mime-type returned for a presign request.

    Any additional keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        write,   # type: str
        access,  # type: str
        type,    # type: str
        **kwargs
    ):
        # mime-type to send to S3.
        self.type = type
        self.access = access
        self.write = write
|
||||
|
||||
|
||||
class HfApi:
    """
    Thin HTTP client for the `/api/*` routes of the configured endpoint.

    Every method raises `requests.exceptions.HTTPError` (through
    `raise_for_status`) when the server answers with an error status.
    """

    def __init__(self, endpoint=None):
        # Fall back to the module-level default when no endpoint is given.
        if endpoint is None:
            endpoint = ENDPOINT
        self.endpoint = endpoint

    @staticmethod
    def _bearer(token):
        # Authorization header shared by all authenticated calls.
        return {"authorization": "Bearer {}".format(token)}

    def login(
        self,
        username,  # type: str
        password,  # type: str
    ):
        # type: (...) -> str
        """
        Call HF API to sign in a user and get a token if credentials are valid.

        Outputs:
            token if credentials are valid

        Throws:
            requests.exceptions.HTTPError if credentials are invalid
        """
        url = "{}/api/login".format(self.endpoint)
        credentials = {"username": username, "password": password}
        response = requests.post(url, json=credentials)
        response.raise_for_status()
        return response.json()["token"]

    def whoami(
        self,
        token,  # type: str
    ):
        # type: (...) -> str
        """
        Call HF API to know "whoami"
        """
        url = "{}/api/whoami".format(self.endpoint)
        response = requests.get(url, headers=self._bearer(token))
        response.raise_for_status()
        return response.json()["user"]

    def logout(self, token):
        # type: (...) -> None
        """
        Call HF API to log out.
        """
        url = "{}/api/logout".format(self.endpoint)
        response = requests.post(url, headers=self._bearer(token))
        response.raise_for_status()

    def presign(self, token, filename):
        # type: (...) -> PresignedUrl
        """
        Call HF API to get a presigned url to upload `filename` to S3.
        """
        url = "{}/api/presign".format(self.endpoint)
        response = requests.post(
            url,
            headers=self._bearer(token),
            json={"filename": filename},
        )
        response.raise_for_status()
        payload = response.json()
        return PresignedUrl(**payload)

    def presign_and_upload(self, token, filename, filepath):
        # type: (...) -> str
        """
        Get a presigned url, then upload file to S3.

        Outputs:
            url: Read-only url for the stored file on S3.
        """
        presigned = self.presign(token, filename=filename)
        # Even though we presign with the correct content-type,
        # the client still has to specify it when uploading the file.
        upload_headers = {"content-type": presigned.type}
        # streaming upload:
        # https://2.python-requests.org/en/master/user/advanced/#streaming-uploads
        with open(filepath, "rb") as stream:
            response = requests.put(presigned.write, data=stream, headers=upload_headers)
            response.raise_for_status()
        return presigned.access

    def list_objs(self, token):
        # type: (...) -> List[S3Obj]
        """
        Call HF API to list all stored files for user.
        """
        url = "{}/api/listObjs".format(self.endpoint)
        response = requests.get(url, headers=self._bearer(token))
        response.raise_for_status()
        return [S3Obj(**entry) for entry in response.json()]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class HfFolder:
    """
    Stores, reads and deletes the API token kept in a file on disk.
    """
    # Location of the token file inside the user's home directory.
    path_token = expanduser("~/.huggingface/token")

    @classmethod
    def save_token(cls, token):
        """
        Save token, creating folder as needed.

        Raises OSError if the folder cannot be created or the file cannot be
        written.
        """
        # Local import: `errno` is the public home of EEXIST (`os.errno` is
        # not a documented attribute and is absent on recent Python 3).
        import errno

        token_dir = os.path.dirname(cls.path_token)
        # One code path for Python 2 and 3: attempt the creation and tolerate
        # only "already exists as a directory".
        try:
            os.makedirs(token_dir)
        except OSError as e:
            if e.errno != errno.EEXIST or not os.path.isdir(token_dir):
                raise
        with open(cls.path_token, 'w') as f:
            f.write(token)

    @classmethod
    def get_token(cls):
        """
        Get token or None if not existent.

        Any I/O failure while opening or reading the file also yields None.
        """
        try:
            with open(cls.path_token, 'r') as f:
                return f.read()
        except (IOError, OSError):
            # IOError covers Python 2, where a missing file is not an OSError;
            # on Python 3 both names refer to the same class.
            return None

    @classmethod
    def delete_token(cls):
        """
        Delete token.
        Do not fail if token does not exist.
        """
        try:
            os.remove(cls.path_token)
        except OSError:
            # Best-effort removal: filesystem errors are ignored, everything
            # else (e.g. KeyboardInterrupt) now propagates.
            pass
|
||||
102
transformers/tests/hf_api_test.py
Normal file
102
transformers/tests/hf_api_test.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import six
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from transformers.hf_api import HfApi, S3Obj, PresignedUrl, HfFolder, HTTPError
|
||||
|
||||
USER = "__DUMMY_TRANSFORMERS_USER__"
|
||||
PASS = "__DUMMY_TRANSFORMERS_PASS__"
|
||||
FILE_KEY = "Test-{}.txt".format(int(time.time()))
|
||||
FILE_PATH = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "fixtures/input.txt"
|
||||
)
|
||||
|
||||
|
||||
|
||||
class HfApiCommonTest(unittest.TestCase):
    """
    Base test case providing one API client shared by the subclasses.
    """
    # Client bound to the staging endpoint instead of the module default.
    _api = HfApi(
        endpoint="https://moon-staging.huggingface.co",
    )
|
||||
|
||||
|
||||
class HfApiLoginTest(HfApiCommonTest):
    """
    Checks `HfApi.login` with a wrong and with the correct password.
    """

    def test_login_invalid(self):
        # A wrong password must surface as an HTTPError.
        self.assertRaises(HTTPError, self._api.login, username=USER, password="fake")

    def test_login_valid(self):
        # The correct credentials must yield a string token.
        received = self._api.login(username=USER, password=PASS)
        self.assertIsInstance(received, six.string_types)
|
||||
|
||||
|
||||
class HfApiEndpointsTest(HfApiCommonTest):
    """
    Exercises the authenticated `HfApi` methods with one shared token.
    """

    @classmethod
    def setUpClass(cls):
        """
        Share this valid token in all tests below.
        """
        cls._token = cls._api.login(username=USER, password=PASS)

    def test_whoami(self):
        # The token must resolve back to the user that logged in.
        self.assertEqual(self._api.whoami(token=self._token), USER)

    def test_presign(self):
        presigned = self._api.presign(token=self._token, filename=FILE_KEY)
        self.assertIsInstance(presigned, PresignedUrl)
        self.assertEqual(presigned.type, "text/plain")

    def test_presign_and_upload(self):
        url = self._api.presign_and_upload(
            token=self._token, filename=FILE_KEY, filepath=FILE_PATH
        )
        self.assertIsInstance(url, six.string_types)

    def test_list_objs(self):
        listing = self._api.list_objs(token=self._token)
        self.assertIsInstance(listing, list)
        # The listing may be empty; only inspect an element when there is one.
        if listing:
            self.assertIsInstance(listing[-1], S3Obj)
|
||||
|
||||
|
||||
|
||||
class HfFolderTest(unittest.TestCase):
    """
    Tests the `HfFolder` token helpers against a throw-away token location.
    """

    def setUp(self):
        """
        Point `HfFolder.path_token` at a temporary directory so the test never
        overwrites or deletes the real token of whoever runs the suite.
        """
        # Local imports keep this test case self-contained.
        import shutil
        import tempfile

        original_path_token = HfFolder.path_token
        tmp_dir = tempfile.mkdtemp()
        # Cleanups run even when the test fails: remove the temporary
        # directory and restore the original token location.
        self.addCleanup(shutil.rmtree, tmp_dir, ignore_errors=True)
        self.addCleanup(setattr, HfFolder, "path_token", original_path_token)
        # The nested folder does not exist yet, so save_token() still has to
        # create it.
        HfFolder.path_token = os.path.join(tmp_dir, ".huggingface", "token")

    def test_token_workflow(self):
        """
        Test the whole token save/get/delete workflow,
        with the desired behavior with respect to non-existent tokens.
        """
        token = "token-{}".format(int(time.time()))
        HfFolder.save_token(token)
        self.assertEqual(HfFolder.get_token(), token)
        HfFolder.delete_token()
        HfFolder.delete_token()
        # ^^ not an error, we test that the
        # second call does not fail.
        self.assertIsNone(HfFolder.get_token())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user