Merge pull request #2044 from huggingface/cli_upload
CLI for authenticated file sharing
This commit is contained in:
10
setup.py
10
setup.py
@@ -36,6 +36,12 @@ To create the package for pypi.
|
|||||||
from io import open
|
from io import open
|
||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
|
||||||
|
# Optional dependency groups, installable as `pip install transformers[<group>]`.
extras = {
    'serving': ['uvicorn', 'fastapi'],
}
# 'all' has to be one flat list of requirement strings: collecting the group
# lists themselves would produce a nested list, which is not a valid
# `extras_require` value.
extras['all'] = [package for group in extras.values() for package in group]
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="transformers",
|
name="transformers",
|
||||||
version="2.2.1",
|
version="2.2.1",
|
||||||
@@ -61,6 +67,10 @@ setup(
|
|||||||
"transformers=transformers.__main__:main",
|
"transformers=transformers.__main__:main",
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
extras_require=extras,
|
||||||
|
scripts=[
|
||||||
|
'transformers-cli'
|
||||||
|
],
|
||||||
# python_requires='>=3.5.0',
|
# python_requires='>=3.5.0',
|
||||||
tests_require=['pytest'],
|
tests_require=['pytest'],
|
||||||
classifiers=[
|
classifiers=[
|
||||||
|
|||||||
23
transformers-cli
Normal file
23
transformers-cli
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
from transformers.commands.user import UserCommands
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Imported locally so this entry point does not rely on the `exit` helper
    # that the `site` module injects for interactive sessions.
    import sys

    parser = ArgumentParser(description='Transformers CLI tool', usage='transformers-cli <command> [<args>]')
    commands_parser = parser.add_subparsers(help='transformers-cli command helpers')

    # Register commands
    UserCommands.register_subcommand(commands_parser)

    # Let's go
    args = parser.parse_args()

    # `func` is only set as a default by the subparsers, so it is missing
    # when no subcommand was given on the command line.
    if not hasattr(args, 'func'):
        parser.print_help()
        sys.exit(1)

    # Run: `func` builds the command object from the parsed arguments.
    service = args.func(args)
    service.run()
|
||||||
12
transformers/commands/__init__.py
Normal file
12
transformers/commands/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
class BaseTransformersCLICommand(ABC):
    """Abstract interface for `transformers-cli` commands."""

    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the command's subparser(s) on `parser`; must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def run(self):
        """Execute the command; must be overridden."""
        raise NotImplementedError
|
||||||
165
transformers/commands/user.py
Normal file
165
transformers/commands/user.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
from argparse import ArgumentParser
|
||||||
|
from getpass import getpass
|
||||||
|
import os
|
||||||
|
|
||||||
|
from transformers.commands import BaseTransformersCLICommand
|
||||||
|
from transformers.hf_api import HfApi, HfFolder, HTTPError
|
||||||
|
|
||||||
|
|
||||||
|
class UserCommands(BaseTransformersCLICommand):
    """Registers the login / whoami / logout / ls / upload subcommands."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Add one subparser per user command to `parser`.

        Each subparser stores a `func` default; calling it with the parsed
        arguments builds the command object whose `run()` does the work.
        """
        # Commands that take no arguments of their own.
        argless_commands = (
            ('login', LoginCommand),
            ('whoami', WhoamiCommand),
            ('logout', LogoutCommand),
            ('ls', ListObjsCommand),
        )
        for name, command_cls in argless_commands:
            parser.add_parser(name).set_defaults(func=command_cls)

        # upload
        upload_parser = parser.add_parser('upload')
        upload_parser.add_argument('file', type=str, help='Local filepath of the file to upload.')
        upload_parser.add_argument('--filename', type=str, default=None, help='Optional: override object filename on S3.')
        upload_parser.set_defaults(func=UploadCommand)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ANSI:
    """
    Helper for en.wikipedia.org/wiki/ANSI_escape_code
    """
    _bold = u"\u001b[1m"
    _reset = u"\u001b[0m"

    @classmethod
    def bold(cls, s):
        """Return `s` wrapped between the bold and reset escape sequences."""
        return "{start}{text}{end}".format(start=cls._bold, text=s, end=cls._reset)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseUserCommand:
    """Common state for the user commands: parsed arguments and an API client."""

    def __init__(self, args):
        """Keep the parsed CLI arguments and create the `HfApi` client."""
        self._api = HfApi()
        self.args = args
|
||||||
|
|
||||||
|
|
||||||
|
class LoginCommand(BaseUserCommand):
    """`login` subcommand: authenticate and store the returned token."""

    def run(self):
        """
        Prompt for credentials, log in through the API and save the token.

        Prints the error and exits with status 1 when the API rejects the call.
        """
        print("""
        _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
        _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
        _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
        _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
        _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|

        """)
        user_name = input("Username: ")
        secret = getpass()
        try:
            token = self._api.login(user_name, secret)
        except HTTPError as err:
            # probably invalid credentials, display error message.
            print(err)
            exit(1)
        HfFolder.save_token(token)
        print("Login successful")
        print("Your token:", token, "\n")
        print("Your token has been saved to", HfFolder.path_token)
|
||||||
|
|
||||||
|
|
||||||
|
class WhoamiCommand(BaseUserCommand):
    """`whoami` subcommand: print the user the saved token belongs to."""

    def run(self):
        """
        Read the saved token and print the user returned by the API for it.

        Prints "Not logged in" and exits with status 0 when no token is saved.
        Prints the error and exits with status 1 when the API call fails.
        """
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            exit()
        try:
            user = self._api.whoami(token)
        except HTTPError as e:
            print(e)
            # Report the failure to the calling shell, as the `ls` command
            # does, instead of finishing with a success status.
            exit(1)
        print(user)
|
||||||
|
|
||||||
|
|
||||||
|
class LogoutCommand(BaseUserCommand):
    """`logout` subcommand: drop the saved token and log out on the server."""

    def run(self):
        """
        Delete the locally saved token, then call the API's logout endpoint.

        Prints "Not logged in" and exits with status 0 when no token is saved.
        Prints the error and exits with status 1 when the API call fails.
        """
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            exit()
        # The local token is removed first, so the machine is logged out even
        # when the request below fails.
        HfFolder.delete_token()
        try:
            self._api.logout(token)
        except HTTPError as e:
            # Same handling as the other commands: show the error and exit
            # non-zero rather than dumping a traceback.
            print(e)
            exit(1)
        print("Successfully logged out.")
|
||||||
|
|
||||||
|
|
||||||
|
class ListObjsCommand(BaseUserCommand):
    """`ls` subcommand: print a table of the files stored for the user."""

    def tabulate(self, rows, headers):
        # type: (List[List[Union[str, int]]], List[str]) -> str
        """
        Render `rows` under `headers` as left-aligned, space-padded columns.

        Inspired by:
        stackoverflow.com/a/8356620/593036
        stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
        """
        # Each column is as wide as its longest cell, header included.
        columns = zip(*(list(rows) + [headers]))
        widths = [max(len(str(cell)) for cell in column) for column in columns]
        template = ("{{:{}}} " * len(headers)).format(*widths)
        separator = ["-" * width for width in widths]
        table = [headers, separator] + list(rows)
        return "\n".join(template.format(*line) for line in table)

    def run(self):
        """
        Fetch the user's stored objects and print them as a table.

        Exits with status 1 when no token is saved or the API call fails,
        and with status 0 after a notice when there is nothing to list.
        """
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            exit(1)
        try:
            objs = self._api.list_objs(token)
        except HTTPError as err:
            print(err)
            exit(1)
        if not objs:
            print("No shared file yet")
            exit()
        rows = []
        for obj in objs:
            rows.append([obj.filename, obj.LastModified, obj.ETag, obj.Size])
        table = self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"])
        print(table)
|
||||||
|
|
||||||
|
|
||||||
|
class UploadCommand(BaseUserCommand):
    """`upload` subcommand: send one local file to S3 through the API."""

    def run(self):
        """
        Confirm with the user, upload the file and print its access url.

        Exits with status 1 when no token is saved, and with status 0 when
        the user declines the confirmation prompt.
        """
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            exit(1)

        local_path = os.path.join(os.getcwd(), self.args.file)
        # The remote name defaults to the local file's base name.
        remote_name = self.args.filename
        if remote_name is None:
            remote_name = os.path.basename(local_path)

        notice = "About to upload file {} to S3 under filename {}".format(
            ANSI.bold(local_path), ANSI.bold(remote_name)
        )
        print(notice)

        answer = input("Proceed? [Y/n] ").lower()
        if answer not in ("", "y", "yes"):
            print("Abort")
            exit()

        print(ANSI.bold("Uploading... This might take a while if file is large"))
        access_url = self._api.presign_and_upload(
            token=token, filename=remote_name, filepath=local_path
        )
        print("Your file now lives at:")
        print(access_url)
|
||||||
196
transformers/hf_api.py
Normal file
196
transformers/hf_api.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
from os.path import expanduser
|
||||||
|
import six
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
|
# Default API base URL; HfApi uses it when it is constructed without an endpoint.
ENDPOINT = "https://huggingface.co"
|
||||||
|
|
||||||
|
class S3Obj:
    """Record for one stored object; unrecognised keyword arguments are ignored."""

    def __init__(
        self,
        filename,  # type: str
        LastModified,  # type: str
        ETag,  # type: str
        Size,  # type: int
        **kwargs
    ):
        # Attribute names mirror the keyword names they are built from.
        self.filename, self.LastModified = filename, LastModified
        self.ETag, self.Size = ETag, Size
|
||||||
|
|
||||||
|
|
||||||
|
class PresignedUrl:
    """Record for a presign response; unrecognised keyword arguments are ignored."""

    def __init__(
        self,
        write,  # type: str
        access,  # type: str
        type,  # type: str
        **kwargs
    ):
        self.write, self.access = write, access
        # mime-type to send to S3.
        self.type = type
|
||||||
|
|
||||||
|
|
||||||
|
class HfApi:
    """HTTP client for the login and file-sharing API endpoints."""

    def __init__(self, endpoint=None):
        # Fall back to the module-level default when no endpoint is given.
        if endpoint is None:
            endpoint = ENDPOINT
        self.endpoint = endpoint

    @staticmethod
    def _auth_headers(token):
        # type: (str) -> dict
        """Build the bearer-token header sent with authenticated calls."""
        return {"authorization": "Bearer {}".format(token)}

    def login(
        self,
        username,  # type: str
        password,  # type: str
    ):
        # type: (...) -> str
        """
        Call HF API to sign in a user and get a token if credentials are valid.

        Outputs:
            token if credentials are valid

        Throws:
            requests.exceptions.HTTPError if credentials are invalid
        """
        url = "{}/api/login".format(self.endpoint)
        response = requests.post(url, json={"username": username, "password": password})
        response.raise_for_status()
        return response.json()["token"]

    def whoami(
        self,
        token,  # type: str
    ):
        # type: (...) -> str
        """
        Call HF API to know "whoami"
        """
        url = "{}/api/whoami".format(self.endpoint)
        response = requests.get(url, headers=self._auth_headers(token))
        response.raise_for_status()
        return response.json()["user"]

    def logout(self, token):
        # type: (...) -> None
        """
        Call HF API to log out.
        """
        url = "{}/api/logout".format(self.endpoint)
        response = requests.post(url, headers=self._auth_headers(token))
        response.raise_for_status()

    def presign(self, token, filename):
        # type: (...) -> PresignedUrl
        """
        Call HF API to get a presigned url to upload `filename` to S3.
        """
        url = "{}/api/presign".format(self.endpoint)
        response = requests.post(
            url,
            headers=self._auth_headers(token),
            json={"filename": filename},
        )
        response.raise_for_status()
        payload = response.json()
        return PresignedUrl(**payload)

    def presign_and_upload(self, token, filename, filepath):
        # type: (...) -> str
        """
        Get a presigned url, then upload file to S3.

        Outputs:
            url: Read-only url for the stored file on S3.
        """
        urls = self.presign(token, filename=filename)
        # streaming upload:
        # https://2.python-requests.org/en/master/user/advanced/#streaming-uploads
        #
        # Even though we presign with the correct content-type,
        # the client still has to specify it when uploading the file.
        upload_headers = {"content-type": urls.type}
        with open(filepath, "rb") as stream:
            response = requests.put(urls.write, data=stream, headers=upload_headers)
            response.raise_for_status()
        return urls.access

    def list_objs(self, token):
        # type: (...) -> List[S3Obj]
        """
        Call HF API to list all stored files for user.
        """
        url = "{}/api/listObjs".format(self.endpoint)
        response = requests.get(url, headers=self._auth_headers(token))
        response.raise_for_status()
        return [S3Obj(**entry) for entry in response.json()]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class HfFolder:
    """Stores the API token in a file under the user's home directory."""

    # Location of the token file; read through `cls` so it can be overridden.
    path_token = expanduser("~/.huggingface/token")

    @classmethod
    def save_token(cls, token):
        """
        Save token, creating folder as needed.

        Raises OSError when the folder cannot be created.
        """
        folder = os.path.dirname(cls.path_token)
        try:
            os.makedirs(folder)
        except OSError:
            # An already existing folder is fine on Python 2 and 3 alike;
            # any other failure is re-raised untouched.
            if not os.path.isdir(folder):
                raise
        with open(cls.path_token, 'w+') as f:
            f.write(token)

    @classmethod
    def get_token(cls):
        """
        Get token or None if not existent.
        """
        try:
            with open(cls.path_token, 'r') as f:
                return f.read()
        except (IOError, OSError):
            # A missing file raises IOError on Python 2 and an OSError
            # subclass on Python 3. Unrelated exceptions now propagate.
            return None

    @classmethod
    def delete_token(cls):
        """
        Delete token.
        Do not fail if token does not exist.

        Raises OSError when the file exists but cannot be removed.
        """
        import errno

        try:
            os.remove(cls.path_token)
        except OSError as e:
            # Only "no such file" is ignored, so a token that could not be
            # removed is not silently reported as deleted.
            if e.errno != errno.ENOENT:
                raise
|
||||||
102
transformers/tests/hf_api_test.py
Normal file
102
transformers/tests/hf_api_test.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import six
|
||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers.hf_api import HfApi, S3Obj, PresignedUrl, HfFolder, HTTPError
|
||||||
|
|
||||||
|
# Credentials of the dummy account used by the tests.
USER = "__DUMMY_TRANSFORMERS_USER__"
PASS = "__DUMMY_TRANSFORMERS_PASS__"
# Object name derived from the current time in whole seconds.
FILE_KEY = "Test-%d.txt" % int(time.time())
# Fixture file located next to this test module.
_TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
FILE_PATH = os.path.join(_TESTS_DIR, "fixtures/input.txt")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class HfApiCommonTest(unittest.TestCase):
    """Base test case holding an API client bound to the staging endpoint."""

    _api = HfApi(endpoint="https://moon-staging.huggingface.co")
|
||||||
|
|
||||||
|
|
||||||
|
class HfApiLoginTest(HfApiCommonTest):
    """Login with a wrong and with the right password."""

    def test_login_invalid(self):
        """A wrong password makes `login` raise HTTPError."""
        self.assertRaises(HTTPError, self._api.login, username=USER, password="fake")

    def test_login_valid(self):
        """The right password yields a string token."""
        self.assertIsInstance(
            self._api.login(username=USER, password=PASS), six.string_types
        )
|
||||||
|
|
||||||
|
|
||||||
|
class HfApiEndpointsTest(HfApiCommonTest):
    """Tests of the endpoints that need a token."""

    @classmethod
    def setUpClass(cls):
        """
        Share this valid token in all tests below.
        """
        cls._token = cls._api.login(username=USER, password=PASS)

    def test_whoami(self):
        """`whoami` returns the name of the logged-in user."""
        self.assertEqual(self._api.whoami(token=self._token), USER)

    def test_presign(self):
        """`presign` returns a PresignedUrl with a text/plain type for FILE_KEY."""
        presigned = self._api.presign(token=self._token, filename=FILE_KEY)
        self.assertIsInstance(presigned, PresignedUrl)
        self.assertEqual(presigned.type, "text/plain")

    def test_presign_and_upload(self):
        """Uploading the fixture file returns a string url."""
        url = self._api.presign_and_upload(
            token=self._token, filename=FILE_KEY, filepath=FILE_PATH
        )
        self.assertIsInstance(url, six.string_types)

    def test_list_objs(self):
        """`list_objs` returns a list; its last entry, if any, is an S3Obj."""
        listing = self._api.list_objs(token=self._token)
        self.assertIsInstance(listing, list)
        if listing:
            self.assertIsInstance(listing[-1], S3Obj)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class HfFolderTest(unittest.TestCase):
    """Token save/get/delete tests, run against a throw-away token path."""

    def setUp(self):
        """
        Point `HfFolder.path_token` at a temporary directory for the test.

        Without this the test would overwrite and then delete the real token
        file of whoever runs the suite.
        """
        # Local imports keep this class independent of the file-level imports.
        import shutil
        import tempfile

        scratch_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, scratch_dir, ignore_errors=True)
        # Put the original location back once the test is over.
        self.addCleanup(setattr, HfFolder, "path_token", HfFolder.path_token)
        HfFolder.path_token = os.path.join(scratch_dir, ".huggingface", "token")

    def test_token_workflow(self):
        """
        Test the whole token save/get/delete workflow,
        with the desired behavior with respect to non-existent tokens.
        """
        token = "token-{}".format(int(time.time()))
        HfFolder.save_token(token)
        self.assertEqual(HfFolder.get_token(), token)
        HfFolder.delete_token()
        HfFolder.delete_token()
        # ^^ not an error, we test that the
        # second call does not fail.
        self.assertIsNone(HfFolder.get_token())
|
||||||
|
|
||||||
|
|
||||||
|
# Run the test suite when this module is executed as a script.
if "__main__" == __name__:
    unittest.main()
|
||||||
Reference in New Issue
Block a user