From 7d99e05f76af043702cfdefd0faebab61c0d9886 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 7 Feb 2020 00:03:12 +0100 Subject: [PATCH] file_cache has options to extract archives --- src/transformers/file_utils.py | 51 ++++++++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 3 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 8aafa95f43..93a19a6013 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -8,13 +8,16 @@ import fnmatch import json import logging import os +import shutil import sys +import tarfile import tempfile from contextlib import contextmanager from functools import partial, wraps from hashlib import sha256 from typing import Optional from urllib.parse import urlparse +from zipfile import ZipFile, is_zipfile import boto3 import requests @@ -203,7 +206,14 @@ def filename_to_url(filename, cache_dir=None): def cached_path( - url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None + url_or_filename, + cache_dir=None, + force_download=False, + proxies=None, + resume_download=False, + user_agent=None, + extract_compressed_file=False, + force_extract=False, ) -> Optional[str]: """ Given something that might be a URL (or might be a local path), @@ -215,6 +225,10 @@ def cached_path( force_download: if True, re-dowload the file even if it's already cached in the cache dir. resume_download: if True, resume the download if incompletly recieved file is found. user_agent: Optional string or dict that will be appended to the user-agent on remote requests. + extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed + file in a folder alongside the archive. + force_extract: if True when extract_compressed_file is True and the archive was already extracted, + re-extract the archive and override the folder where it was extracted. 
Return: None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). @@ -229,7 +243,7 @@ def cached_path( if is_remote_url(url_or_filename): # URL, so get it from the cache (downloading if necessary) - return get_from_cache( + output_path = get_from_cache( url_or_filename, cache_dir=cache_dir, force_download=force_download, @@ -239,7 +253,7 @@ def cached_path( ) elif os.path.exists(url_or_filename): # File, and it exists. - return url_or_filename + output_path = url_or_filename elif urlparse(url_or_filename).scheme == "": # File, but it doesn't exist. raise EnvironmentError("file {} not found".format(url_or_filename)) @@ -247,6 +261,37 @@ def cached_path( # Something unknown raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) + if extract_compressed_file: + if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path): + return output_path + + # Path where we extract compressed archives + # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/" + output_dir, output_file = os.path.split(output_path) + output_extract_dir_name = output_file.replace(".", "-") + "-extracted" + output_path_extracted = os.path.join(output_dir, output_extract_dir_name) + + if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract: + return output_path_extracted + + # Prevent parallel extractions + lock_path = output_path + ".lock" + with FileLock(lock_path): + shutil.rmtree(output_path_extracted, ignore_errors=True) + os.makedirs(output_path_extracted) + if is_zipfile(output_path): + with ZipFile(output_path, "r") as zip_file: + zip_file.extractall(output_path_extracted) + zip_file.close() + elif tarfile.is_tarfile(output_path): + tar_file = tarfile.open(output_path) + tar_file.extractall(output_path_extracted) + tar_file.close() + + return output_path_extracted + + return output_path + def split_s3_path(url): """Split a full 
s3 path into the bucket name and path."""