Add support for resumable downloads for HTTP protocol.
This commit is contained in:
@@ -92,6 +92,9 @@ class AutoConfig(object):
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
|
||||
resume_download: (`optional`) boolean, default False:
|
||||
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
|
||||
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
|
||||
@@ -93,6 +93,9 @@ class PretrainedConfig(object):
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
|
||||
resume_download: (`optional`) boolean, default False:
|
||||
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
|
||||
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
@@ -119,6 +122,7 @@ class PretrainedConfig(object):
|
||||
"""
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
force_download = kwargs.pop('force_download', False)
|
||||
resume_download = kwargs.pop('resume_download', False)
|
||||
proxies = kwargs.pop('proxies', None)
|
||||
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
|
||||
|
||||
@@ -130,7 +134,8 @@ class PretrainedConfig(object):
|
||||
config_file = pretrained_model_name_or_path
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
|
||||
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download,
|
||||
proxies=proxies, resume_download=resume_download)
|
||||
except EnvironmentError:
|
||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||
msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
|
||||
|
||||
@@ -22,6 +22,7 @@ from botocore.config import Config
|
||||
from botocore.exceptions import ClientError
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
from contextlib import contextmanager
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -152,7 +153,7 @@ def filename_to_url(filename, cache_dir=None):
|
||||
return url, etag
|
||||
|
||||
|
||||
def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None):
|
||||
def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False):
|
||||
"""
|
||||
Given something that might be a URL (or might be a local path),
|
||||
determine which. If it's a URL, download the file and cache it, and
|
||||
@@ -161,6 +162,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
|
||||
Args:
|
||||
cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
|
||||
force_download: if True, re-dowload the file even if it's already cached in the cache dir.
|
||||
resume_download: if True, resume the download if incompletly recieved file is found.
|
||||
"""
|
||||
if cache_dir is None:
|
||||
cache_dir = TRANSFORMERS_CACHE
|
||||
@@ -173,7 +175,9 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
|
||||
|
||||
if parsed.scheme in ('http', 'https', 's3'):
|
||||
# URL, so get it from the cache (downloading if necessary)
|
||||
return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
|
||||
return get_from_cache(url_or_filename, cache_dir=cache_dir,
|
||||
force_download=force_download, proxies=proxies,
|
||||
resume_download=resume_download)
|
||||
elif os.path.exists(url_or_filename):
|
||||
# File, and it exists.
|
||||
return url_or_filename
|
||||
@@ -234,19 +238,22 @@ def s3_get(url, temp_file, proxies=None):
|
||||
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
|
||||
|
||||
|
||||
def http_get(url, temp_file, proxies=None):
|
||||
req = requests.get(url, stream=True, proxies=proxies)
|
||||
content_length = req.headers.get('Content-Length')
|
||||
total = int(content_length) if content_length is not None else None
|
||||
progress = tqdm(unit="B", total=total)
|
||||
for chunk in req.iter_content(chunk_size=1024):
|
||||
def http_get(url, temp_file, proxies=None, resume_size=0):
|
||||
headers={'Range':'bytes=%d-'%(resume_size,)} if resume_size > 0 else None
|
||||
response = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
||||
if response.status_code == 416: # Range not satisfiable
|
||||
return
|
||||
content_length = response.headers.get('Content-Length')
|
||||
total = resume_size + int(content_length) if content_length is not None else None
|
||||
progress = tqdm(unit="B", total=total, initial=resume_size)
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
progress.update(len(chunk))
|
||||
temp_file.write(chunk)
|
||||
progress.close()
|
||||
|
||||
|
||||
def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10):
|
||||
def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False):
|
||||
"""
|
||||
Given a URL, look for the corresponding dataset in the local cache.
|
||||
If it's not there, download it. Then return the path to the cached file.
|
||||
@@ -289,17 +296,35 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
|
||||
if matching_files:
|
||||
cache_path = os.path.join(cache_dir, matching_files[-1])
|
||||
|
||||
if resume_download:
|
||||
incomplete_path = cache_path + '.incomplete'
|
||||
@contextmanager
|
||||
def _resumable_file_manager():
|
||||
with open(incomplete_path,'a+b') as f:
|
||||
yield f
|
||||
os.remove(incomplete_path)
|
||||
temp_file_manager = _resumable_file_manager
|
||||
if os.path.exists(incomplete_path):
|
||||
resume_size = os.stat(incomplete_path).st_size
|
||||
else:
|
||||
resume_size = 0
|
||||
else:
|
||||
temp_file_manager = tempfile.NamedTemporaryFile
|
||||
resume_size = 0
|
||||
|
||||
if not os.path.exists(cache_path) or force_download:
|
||||
# Download to temporary file, then copy to cache dir once finished.
|
||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
||||
with tempfile.NamedTemporaryFile() as temp_file:
|
||||
with temp_file_manager() as temp_file:
|
||||
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
||||
|
||||
# GET file object
|
||||
if url.startswith("s3://"):
|
||||
if resume_download:
|
||||
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
|
||||
s3_get(url, temp_file, proxies=proxies)
|
||||
else:
|
||||
http_get(url, temp_file, proxies=proxies)
|
||||
http_get(url, temp_file, proxies=proxies, resume_size=resume_size)
|
||||
|
||||
# we are copying the file before closing it, so flush to avoid truncation
|
||||
temp_file.flush()
|
||||
|
||||
@@ -112,6 +112,9 @@ class AutoModel(object):
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
|
||||
resume_download: (`optional`) boolean, default False:
|
||||
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
|
||||
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
@@ -237,6 +240,8 @@ class AutoModelWithLMHead(object):
|
||||
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
resume_download: (`optional`) boolean, default False:
|
||||
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
|
||||
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
@@ -357,6 +362,9 @@ class AutoModelForSequenceClassification(object):
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
|
||||
resume_download: (`optional`) boolean, default False:
|
||||
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
|
||||
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
|
||||
@@ -109,6 +109,9 @@ class TFAutoModel(object):
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
|
||||
resume_download: (`optional`) boolean, default False:
|
||||
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
|
||||
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
@@ -237,6 +240,9 @@ class TFAutoModelWithLMHead(object):
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
|
||||
resume_download: (`optional`) boolean, default False:
|
||||
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
|
||||
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
@@ -360,6 +366,9 @@ class TFAutoModelForSequenceClassification(object):
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
|
||||
resume_download: (`optional`) boolean, default False:
|
||||
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
|
||||
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
@@ -472,6 +481,9 @@ class TFAutoModelForQuestionAnswering(object):
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
|
||||
resume_download: (`optional`) boolean, default False:
|
||||
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
|
||||
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
|
||||
@@ -176,6 +176,9 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
|
||||
resume_download: (`optional`) boolean, default False:
|
||||
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
|
||||
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
@@ -201,6 +204,7 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
from_pt = kwargs.pop('from_pt', False)
|
||||
force_download = kwargs.pop('force_download', False)
|
||||
resume_download = kwargs.pop('resume_download', False)
|
||||
proxies = kwargs.pop('proxies', None)
|
||||
|
||||
# Load config
|
||||
@@ -209,6 +213,7 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
pretrained_model_name_or_path, *model_args,
|
||||
cache_dir=cache_dir, return_unused_kwargs=True,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
@@ -236,7 +241,8 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download,
|
||||
resume_download=resume_download, proxies=proxies)
|
||||
except EnvironmentError as e:
|
||||
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||
logger.error(
|
||||
|
||||
@@ -246,6 +246,9 @@ class PreTrainedModel(nn.Module):
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
|
||||
resume_download: (`optional`) boolean, default False:
|
||||
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
|
||||
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
@@ -275,6 +278,7 @@ class PreTrainedModel(nn.Module):
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
from_tf = kwargs.pop('from_tf', False)
|
||||
force_download = kwargs.pop('force_download', False)
|
||||
resume_download = kwargs.pop('resume_download', False)
|
||||
proxies = kwargs.pop('proxies', None)
|
||||
output_loading_info = kwargs.pop('output_loading_info', False)
|
||||
|
||||
@@ -284,6 +288,7 @@ class PreTrainedModel(nn.Module):
|
||||
pretrained_model_name_or_path, *model_args,
|
||||
cache_dir=cache_dir, return_unused_kwargs=True,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
@@ -315,7 +320,8 @@ class PreTrainedModel(nn.Module):
|
||||
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download,
|
||||
proxies=proxies, resume_download=resume_download)
|
||||
except EnvironmentError:
|
||||
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||
msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
|
||||
|
||||
@@ -87,6 +87,9 @@ class AutoTokenizer(object):
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the vocabulary files and override the cached versions if they exists.
|
||||
|
||||
resume_download: (`optional`) boolean, default False:
|
||||
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
|
||||
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
|
||||
@@ -251,6 +251,9 @@ class PreTrainedTokenizer(object):
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the vocabulary files and override the cached versions if they exists.
|
||||
|
||||
resume_download: (`optional`) boolean, default False:
|
||||
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
|
||||
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
@@ -286,6 +289,7 @@ class PreTrainedTokenizer(object):
|
||||
def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
force_download = kwargs.pop('force_download', False)
|
||||
resume_download = kwargs.pop('resume_download', False)
|
||||
proxies = kwargs.pop('proxies', None)
|
||||
|
||||
s3_models = list(cls.max_model_input_sizes.keys())
|
||||
@@ -352,7 +356,7 @@ class PreTrainedTokenizer(object):
|
||||
if file_path is None:
|
||||
resolved_vocab_files[file_id] = None
|
||||
else:
|
||||
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
|
||||
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download)
|
||||
except EnvironmentError:
|
||||
if pretrained_model_name_or_path in s3_models:
|
||||
msg = "Couldn't reach server at '{}' to download vocabulary files."
|
||||
|
||||
Reference in New Issue
Block a user