Tokenizer.from_pretrained: fetch all possible files remotely

This commit is contained in:
Julien Chaumond
2020-01-15 21:07:26 +00:00
parent 99f9243de5
commit 23a2cea8cb
4 changed files with 109 additions and 81 deletions

View File

@@ -14,6 +14,7 @@ import tempfile
from contextlib import contextmanager
from functools import partial, wraps
from hashlib import sha256
from typing import Optional
from urllib.parse import urlparse
import boto3
@@ -122,7 +123,7 @@ def is_remote_url(url_or_filename):
return parsed.scheme in ("http", "https", "s3")
def hf_bucket_url(identifier, postfix=None, cdn=False):
def hf_bucket_url(identifier, postfix=None, cdn=False) -> str:
endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
if postfix is None:
return "/".join((endpoint, identifier))
@@ -182,7 +183,7 @@ def filename_to_url(filename, cache_dir=None):
def cached_path(
url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None
):
) -> Optional[str]:
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
@@ -193,6 +194,10 @@ def cached_path(
force_download: if True, re-dowload the file even if it's already cached in the cache dir.
resume_download: if True, resume the download if incompletly recieved file is found.
user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
Return:
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
Local path (string) otherwise
"""
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
@@ -306,10 +311,14 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
def get_from_cache(
url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None
):
) -> Optional[str]:
"""
Given a URL, look for the corresponding dataset in the local cache.
Given a URL, look for the corresponding file in the local cache.
If it's not there, download it. Then return the path to the cached file.
Return:
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
Local path (string) otherwise
"""
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
@@ -336,16 +345,25 @@ def get_from_cache(
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
# If we don't have a connection (etag is None) and can't identify the file
# etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
# try to get the last downloaded one
if not os.path.exists(cache_path) and etag is None:
matching_files = [
file
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
if not file.endswith(".json") and not file.endswith(".lock")
]
if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1])
if etag is None:
if os.path.exists(cache_path):
return cache_path
else:
matching_files = [
file
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
if not file.endswith(".json") and not file.endswith(".lock")
]
if len(matching_files) > 0:
return os.path.join(cache_dir, matching_files[-1])
else:
return None
# From now on, etag is not None.
if os.path.exists(cache_path) and not force_download:
return cache_path
# Prevent parallel downloads of the same file with a lock.
lock_path = cache_path + ".lock"
@@ -368,29 +386,26 @@ def get_from_cache(
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
resume_size = 0
if etag is not None and (not os.path.exists(cache_path) or force_download):
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with temp_file_manager() as temp_file:
logger.info(
"%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name
)
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with temp_file_manager() as temp_file:
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
# GET file object
if url.startswith("s3://"):
if resume_download:
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
s3_get(url, temp_file, proxies=proxies)
else:
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
# GET file object
if url.startswith("s3://"):
if resume_download:
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
s3_get(url, temp_file, proxies=proxies)
else:
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
logger.info("storing %s in cache at %s", url, cache_path)
os.rename(temp_file.name, cache_path)
logger.info("storing %s in cache at %s", url, cache_path)
os.rename(temp_file.name, cache_path)
logger.info("creating metadata file for %s", cache_path)
meta = {"url": url, "etag": etag}
meta_path = cache_path + ".json"
with open(meta_path, "w") as meta_file:
json.dump(meta, meta_file)
logger.info("creating metadata file for %s", cache_path)
meta = {"url": url, "etag": etag}
meta_path = cache_path + ".json"
with open(meta_path, "w") as meta_file:
json.dump(meta, meta_file)
return cache_path