add force_download option to from_pretrained methods

This commit is contained in:
thomwolf
2019-08-20 10:56:12 +02:00
parent c589862b78
commit fecaed0ed4
3 changed files with 24 additions and 8 deletions

View File

@@ -93,12 +93,15 @@ def filename_to_url(filename, cache_dir=None):
return url, etag
def cached_path(url_or_filename, cache_dir=None):
def cached_path(url_or_filename, cache_dir=None, force_download=False):
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
return the path to the cached file. If it's already a local path,
make sure the file exists and then return the path.
Args:
cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
force_download: if True, re-dowload the file even if it's already cached in the cache dir.
"""
if cache_dir is None:
cache_dir = PYTORCH_TRANSFORMERS_CACHE
@@ -111,7 +114,7 @@ def cached_path(url_or_filename, cache_dir=None):
if parsed.scheme in ('http', 'https', 's3'):
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir)
return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download)
elif os.path.exists(url_or_filename):
# File, and it exists.
return url_or_filename
@@ -184,7 +187,7 @@ def http_get(url, temp_file):
progress.close()
def get_from_cache(url, cache_dir=None):
def get_from_cache(url, cache_dir=None, force_download=False):
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
@@ -227,11 +230,11 @@ def get_from_cache(url, cache_dir=None):
if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1])
if not os.path.exists(cache_path):
if not os.path.exists(cache_path) or force_download:
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file:
logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
# GET file object
if url.startswith("s3://"):