[cli] Uploads: add progress bar (#2078)
* [cli] Uploads: add progress bar see https://github.com/huggingface/transformers/pull/2044#discussion_r354057827 for context * rename + documentation * Add auto-referential comment
This commit is contained in:
@@ -16,10 +16,11 @@ from __future__ import absolute_import, division, print_function
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from os.path import expanduser
|
from os.path import expanduser
|
||||||
import six
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
import six
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
ENDPOINT = "https://huggingface.co"
|
ENDPOINT = "https://huggingface.co"
|
||||||
|
|
||||||
@@ -129,10 +130,13 @@ class HfApi:
|
|||||||
# Even though we presign with the correct content-type,
|
# Even though we presign with the correct content-type,
|
||||||
# the client still has to specify it when uploading the file.
|
# the client still has to specify it when uploading the file.
|
||||||
with open(filepath, "rb") as f:
|
with open(filepath, "rb") as f:
|
||||||
|
pf = TqdmProgressFileReader(f)
|
||||||
|
|
||||||
r = requests.put(urls.write, data=f, headers={
|
r = requests.put(urls.write, data=f, headers={
|
||||||
"content-type": urls.type,
|
"content-type": urls.type,
|
||||||
})
|
})
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
pf.close()
|
||||||
return urls.access
|
return urls.access
|
||||||
|
|
||||||
def list_objs(self, token):
|
def list_objs(self, token):
|
||||||
@@ -148,6 +152,34 @@ class HfApi:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TqdmProgressFileReader:
|
||||||
|
"""
|
||||||
|
Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`)
|
||||||
|
and override `f.read()` so as to display a tqdm progress bar.
|
||||||
|
|
||||||
|
see github.com/huggingface/transformers/pull/2078#discussion_r354739608
|
||||||
|
for implementation details.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
f # type: io.BufferedReader
|
||||||
|
):
|
||||||
|
self.f = f
|
||||||
|
self.total_size = os.fstat(f.fileno()).st_size # type: int
|
||||||
|
self.pbar = tqdm(total=self.total_size, leave=False)
|
||||||
|
if six.PY3:
|
||||||
|
# does not work unless PY3
|
||||||
|
# no big deal as the CLI does not currently support PY2 anyways.
|
||||||
|
self.read = f.read
|
||||||
|
f.read = self._read
|
||||||
|
|
||||||
|
def _read(self, n=-1):
|
||||||
|
self.pbar.update(n)
|
||||||
|
return self.read(n)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.pbar.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class HfFolder:
|
class HfFolder:
|
||||||
|
|||||||
Reference in New Issue
Block a user