Switch from comments to annotations for types.
This commit is contained in:
@@ -27,14 +27,7 @@ ENDPOINT = "https://huggingface.co"
|
||||
|
||||
|
||||
class S3Obj:
|
||||
def __init__(
|
||||
self,
|
||||
filename, # type: str
|
||||
LastModified, # type: str
|
||||
ETag, # type: str
|
||||
Size, # type: int
|
||||
**kwargs
|
||||
):
|
||||
def __init__(self, filename: str, LastModified: str, ETag: str, Size: int, **kwargs):
|
||||
self.filename = filename
|
||||
self.LastModified = LastModified
|
||||
self.ETag = ETag
|
||||
@@ -42,13 +35,7 @@ class S3Obj:
|
||||
|
||||
|
||||
class PresignedUrl:
|
||||
def __init__(
|
||||
self,
|
||||
write, # type: str
|
||||
access, # type: str
|
||||
type, # type: str
|
||||
**kwargs
|
||||
):
|
||||
def __init__(self, write: str, access: str, type: str, **kwargs):
|
||||
self.write = write
|
||||
self.access = access
|
||||
self.type = type # mime-type to send to S3.
|
||||
@@ -58,12 +45,7 @@ class HfApi:
|
||||
def __init__(self, endpoint=None):
|
||||
self.endpoint = endpoint if endpoint is not None else ENDPOINT
|
||||
|
||||
def login(
|
||||
self,
|
||||
username, # type: str
|
||||
password, # type: str
|
||||
):
|
||||
# type: (...) -> str
|
||||
def login(self, username: str, password: str) -> str:
|
||||
"""
|
||||
Call HF API to sign in a user and get a token if credentials are valid.
|
||||
|
||||
@@ -79,10 +61,7 @@ class HfApi:
|
||||
d = r.json()
|
||||
return d["token"]
|
||||
|
||||
def whoami(
|
||||
self, token, # type: str
|
||||
):
|
||||
# type: (...) -> str
|
||||
def whoami(self, token: str) -> str:
|
||||
"""
|
||||
Call HF API to know "whoami"
|
||||
"""
|
||||
@@ -92,8 +71,7 @@ class HfApi:
|
||||
d = r.json()
|
||||
return d["user"]
|
||||
|
||||
def logout(self, token):
|
||||
# type: (...) -> None
|
||||
def logout(self, token: str) -> None:
|
||||
"""
|
||||
Call HF API to log out.
|
||||
"""
|
||||
@@ -101,19 +79,17 @@ class HfApi:
|
||||
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)})
|
||||
r.raise_for_status()
|
||||
|
||||
def presign(self, token, filename):
|
||||
# type: (...) -> PresignedUrl
|
||||
def presign(self, token: str, filename) -> PresignedUrl:
|
||||
"""
|
||||
Call HF API to get a presigned url to upload `filename` to S3.
|
||||
"""
|
||||
path = "{}/api/presign".format(self.endpoint)
|
||||
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename},)
|
||||
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename})
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
return PresignedUrl(**d)
|
||||
|
||||
def presign_and_upload(self, token, filename, filepath):
|
||||
# type: (...) -> str
|
||||
def presign_and_upload(self, token: str, filename, filepath) -> str:
|
||||
"""
|
||||
Get a presigned url, then upload file to S3.
|
||||
|
||||
@@ -157,7 +133,7 @@ class TqdmProgressFileReader:
|
||||
|
||||
def __init__(self, f: io.BufferedReader):
|
||||
self.f = f
|
||||
self.total_size = os.fstat(f.fileno()).st_size # type: int
|
||||
self.total_size = os.fstat(f.fileno()).st_size
|
||||
self.pbar = tqdm(total=self.total_size, leave=False)
|
||||
self.read = f.read
|
||||
f.read = self._read
|
||||
|
||||
Reference in New Issue
Block a user