From a4c9338b83ba612b5f5aec645f375d048d9a7647 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 20 Dec 2019 20:56:59 +0100 Subject: [PATCH] Prevent parallel downloads of the same file with a lock. Since the file is written to the filesystem, a filesystem lock is the way to go here. Add a dependency on the third-party filelock library to get cross-platform functionality. --- setup.py | 1 + transformers/file_utils.py | 89 +++++++++++++++++++++----------------- 2 files changed, 50 insertions(+), 40 deletions(-) diff --git a/setup.py b/setup.py index cd64a6ce90..fe2e1526bf 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ setup( "tests.*", "tests"]), install_requires=['numpy', 'boto3', + 'filelock', 'requests', 'tqdm', 'regex != 2019.12.17', diff --git a/transformers/file_utils.py b/transformers/file_utils.py index 61ff1d00bc..ec925c6160 100644 --- a/transformers/file_utils.py +++ b/transformers/file_utils.py @@ -24,6 +24,8 @@ from tqdm.auto import tqdm from contextlib import contextmanager from . import __version__ +from filelock import FileLock + logger = logging.getLogger(__name__) # pylint: disable=invalid-name try: @@ -333,53 +335,60 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag # If we don't have a connection (etag is None) and can't identify the file # try to get the last downloaded one if not os.path.exists(cache_path) and etag is None: - matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') - matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) + matching_files = [ + file + for file in fnmatch.filter(os.listdir(cache_dir), filename + '.*') + if not file.endswith('.json') and not file.endswith('.lock') + ] if matching_files: cache_path = os.path.join(cache_dir, matching_files[-1]) - if resume_download: - incomplete_path = cache_path + '.incomplete' - @contextmanager - def _resumable_file_manager(): - with open(incomplete_path,'a+b') as f: - yield f - temp_file_manager = _resumable_file_manager - if os.path.exists(incomplete_path): - resume_size = os.stat(incomplete_path).st_size - else: - resume_size = 0 - else: - temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False) - resume_size = 0 + # Prevent parallel downloads of the same file with a lock. + lock_path = cache_path + '.lock' + with FileLock(lock_path): - if etag is not None and (not os.path.exists(cache_path) or force_download): - # Download to temporary file, then copy to cache dir once finished. - # Otherwise you get corrupt cache entries if the download gets interrupted. - with temp_file_manager() as temp_file: - logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name) - - # GET file object - if url.startswith("s3://"): - if resume_download: - logger.warn('Warning: resumable downloads are not implemented for "s3://" urls') - s3_get(url, temp_file, proxies=proxies) + if resume_download: + incomplete_path = cache_path + '.incomplete' + @contextmanager + def _resumable_file_manager(): + with open(incomplete_path,'a+b') as f: + yield f + temp_file_manager = _resumable_file_manager + if os.path.exists(incomplete_path): + resume_size = os.stat(incomplete_path).st_size else: - http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent) + resume_size = 0 + else: + temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False) + resume_size = 0 - # we are copying the file before closing it, so flush to avoid truncation - temp_file.flush() + if etag is not None and (not os.path.exists(cache_path) or force_download): + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. + with temp_file_manager() as temp_file: + logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name) - logger.info("storing %s in cache at %s", url, cache_path) - os.rename(temp_file.name, cache_path) + # GET file object + if url.startswith("s3://"): + if resume_download: + logger.warn('Warning: resumable downloads are not implemented for "s3://" urls') + s3_get(url, temp_file, proxies=proxies) + else: + http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent) - logger.info("creating metadata file for %s", cache_path) - meta = {'url': url, 'etag': etag} - meta_path = cache_path + '.json' - with open(meta_path, 'w') as meta_file: - output_string = json.dumps(meta) - if sys.version_info[0] == 2 and isinstance(output_string, str): - output_string = unicode(output_string, 'utf-8') # The beauty of python 2 - meta_file.write(output_string) + # we are copying the file before closing it, so flush to avoid truncation + temp_file.flush() + + logger.info("storing %s in cache at %s", url, cache_path) + os.rename(temp_file.name, cache_path) + + logger.info("creating metadata file for %s", cache_path) + meta = {'url': url, 'etag': etag} + meta_path = cache_path + '.json' + with open(meta_path, 'w') as meta_file: + output_string = json.dumps(meta) + if sys.version_info[0] == 2 and isinstance(output_string, str): + output_string = unicode(output_string, 'utf-8') # The beauty of python 2 + meta_file.write(output_string) return cache_path