Add support for resumable downloads for HTTP protocol.

This commit is contained in:
Sergey Mironov
2019-10-24 18:15:55 +03:00
parent 0e64fec1ab
commit 0e4cc050d6
9 changed files with 87 additions and 15 deletions

View File

@@ -246,6 +246,9 @@ class PreTrainedModel(nn.Module):
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
@@ -275,6 +278,7 @@ class PreTrainedModel(nn.Module):
cache_dir = kwargs.pop('cache_dir', None)
from_tf = kwargs.pop('from_tf', False)
force_download = kwargs.pop('force_download', False)
resume_download = kwargs.pop('resume_download', False)
proxies = kwargs.pop('proxies', None)
output_loading_info = kwargs.pop('output_loading_info', False)
@@ -284,6 +288,7 @@ class PreTrainedModel(nn.Module):
pretrained_model_name_or_path, *model_args,
cache_dir=cache_dir, return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
**kwargs
)
else:
@@ -315,7 +320,8 @@ class PreTrainedModel(nn.Module):
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download,
proxies=proxies, resume_download=resume_download)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
msg = "Couldn't reach server at '{}' to download pretrained weights.".format(