From 96fa9a8a70a52221446b0b887f99c90c5ce31eeb Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Wed, 4 Dec 2019 17:22:50 -0500 Subject: [PATCH] Python 2 + Post mime-type to S3 --- transformers/hf_api.py | 74 ++++++++++++++++++++++++------- transformers/tests/hf_api_test.py | 20 ++++++--- 2 files changed, 73 insertions(+), 21 deletions(-) diff --git a/transformers/hf_api.py b/transformers/hf_api.py index 238762ebf8..c21592a838 100644 --- a/transformers/hf_api.py +++ b/transformers/hf_api.py @@ -14,9 +14,9 @@ # limitations under the License. from __future__ import absolute_import, division, print_function -from typing import List, NamedTuple import os from os.path import expanduser +import six import requests from requests.exceptions import HTTPError @@ -24,23 +24,43 @@ from requests.exceptions import HTTPError ENDPOINT = "https://huggingface.co" class S3Obj: - def __init__(self, filename: str, LastModified: str, ETag: str, Size: int): + def __init__( + self, + filename, # type: str + LastModified, # type: str + ETag, # type: str + Size, # type: int + **kwargs + ): self.filename = filename self.LastModified = LastModified self.ETag = ETag self.Size = Size -class PresignedUrl(NamedTuple): - write: str - access: str +class PresignedUrl: + def __init__( + self, + write, # type: str + access, # type: str + type, # type: str + **kwargs + ): + self.write = write + self.access = access + self.type = type # mime-type to send to S3. class HfApi: def __init__(self, endpoint=None): self.endpoint = endpoint if endpoint is not None else ENDPOINT - def login(self, username: str, password: str) -> str: + def login( + self, + username, # type: str + password, # type: str + ): + # type: (...) -> str """ Call HF API to sign in a user and get a token if credentials are valid. @@ -56,7 +76,11 @@ class HfApi: d = r.json() return d["token"] - def whoami(self, token: str) -> str: + def whoami( + self, + token, # type: str + ): + # type: (...) -> str """ Call HF API to know "whoami" """ @@ -66,7 +90,8 @@ class HfApi: d = r.json() return d["user"] - def logout(self, token: str): + def logout(self, token): + # type: (...) -> void """ Call HF API to log out. """ @@ -74,7 +99,8 @@ class HfApi: r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}) r.raise_for_status() - def presign(self, token: str, filename: str) -> PresignedUrl: + def presign(self, token, filename): + # type: (...) -> PresignedUrl """ Call HF API to get a presigned url to upload `filename` to S3. """ @@ -88,7 +114,8 @@ class HfApi: d = r.json() return PresignedUrl(**d) - def presign_and_upload(self, token: str, filename: str, filepath: str) -> str: + def presign_and_upload(self, token, filename, filepath): + # type: (...) -> str """ Get a presigned url, then upload file to S3. @@ -98,12 +125,18 @@ class HfApi: urls = self.presign(token, filename=filename) # streaming upload: # https://2.python-requests.org/en/master/user/advanced/#streaming-uploads + # + # Even though we presign with the correct content-type, + # the client still has to specify it when uploading the file. with open(filepath, "rb") as f: - r = requests.put(urls.write, data=f) + r = requests.put(urls.write, data=f, headers={ + "content-type": urls.type, + }) r.raise_for_status() return urls.access - def list_objs(self, token: str) -> List[S3Obj]: + def list_objs(self, token): + # type: (...) -> List[S3Obj] """ Call HF API to list all stored files for user. """ @@ -121,11 +154,20 @@ class HfFolder: path_token = expanduser("~/.huggingface/token") @classmethod - def save_token(cls, token: str): + def save_token(cls, token): """ Save token, creating folder as needed. """ - os.makedirs(os.path.dirname(cls.path_token), exist_ok=True) + if six.PY3: + os.makedirs(os.path.dirname(cls.path_token), exist_ok=True) + else: + # Python 2 + try: + os.makedirs(os.path.dirname(cls.path_token)) + except OSError as e: + if e.errno != os.errno.EEXIST: + raise e + pass with open(cls.path_token, 'w+') as f: f.write(token) @@ -137,7 +179,9 @@ class HfFolder: try: with open(cls.path_token, 'r') as f: return f.read() - except FileNotFoundError: + except: + # this is too wide. When Py2 is dead use: + # `except FileNotFoundError:` instead return None @classmethod diff --git a/transformers/tests/hf_api_test.py b/transformers/tests/hf_api_test.py index 59822344ba..92d41b6dff 100644 --- a/transformers/tests/hf_api_test.py +++ b/transformers/tests/hf_api_test.py @@ -15,6 +15,7 @@ from __future__ import absolute_import, division, print_function import os +import six import time import unittest @@ -40,7 +41,7 @@ class HfApiLoginTest(HfApiCommonTest): def test_login_valid(self): token = self._api.login(username=USER, password=PASS) - self.assertIsInstance(token, str) + self.assertIsInstance(token, six.string_types) class HfApiEndpointsTest(HfApiCommonTest): @@ -56,19 +57,22 @@ class HfApiEndpointsTest(HfApiCommonTest): self.assertEqual(user, USER) def test_presign(self): - url = self._api.presign(token=self._token, filename=FILE_KEY) - self.assertIsInstance(url, PresignedUrl) + urls = self._api.presign(token=self._token, filename=FILE_KEY) + self.assertIsInstance(urls, PresignedUrl) + self.assertEqual(urls.type, "text/plain") def test_presign_and_upload(self): access_url = self._api.presign_and_upload( token=self._token, filename=FILE_KEY, filepath=FILE_PATH ) - self.assertIsInstance(access_url, str) + self.assertIsInstance(access_url, six.string_types) def test_list_objs(self): objs = self._api.list_objs(token=self._token) - o = objs[-1] - self.assertIsInstance(o, S3Obj) + self.assertIsInstance(objs, list) + if len(objs) > 0: + o = objs[-1] + self.assertIsInstance(o, S3Obj) @@ -92,3 +96,7 @@ class HfFolderTest(unittest.TestCase): HfFolder.get_token(), None ) + + +if __name__ == "__main__": + unittest.main()