Add support for resumable downloads for HTTP protocol.

This commit is contained in:
Sergey Mironov
2019-10-24 18:15:55 +03:00
parent 0e64fec1ab
commit 0e4cc050d6
9 changed files with 87 additions and 15 deletions

View File

@@ -93,6 +93,9 @@ class PretrainedConfig(object):
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
@@ -119,6 +122,7 @@ class PretrainedConfig(object):
"""
cache_dir = kwargs.pop('cache_dir', None)
force_download = kwargs.pop('force_download', False)
resume_download = kwargs.pop('resume_download', False)
proxies = kwargs.pop('proxies', None)
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
@@ -130,7 +134,8 @@ class PretrainedConfig(object):
config_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download,
proxies=proxies, resume_download=resume_download)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(