Tokenizer.from_pretrained: fetch all possible files remotely
This commit is contained in:
@@ -14,6 +14,7 @@ import tempfile
|
||||
from contextlib import contextmanager
|
||||
from functools import partial, wraps
|
||||
from hashlib import sha256
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import boto3
|
||||
@@ -122,7 +123,7 @@ def is_remote_url(url_or_filename):
|
||||
return parsed.scheme in ("http", "https", "s3")
|
||||
|
||||
|
||||
def hf_bucket_url(identifier, postfix=None, cdn=False):
|
||||
def hf_bucket_url(identifier, postfix=None, cdn=False) -> str:
|
||||
endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
|
||||
if postfix is None:
|
||||
return "/".join((endpoint, identifier))
|
||||
@@ -182,7 +183,7 @@ def filename_to_url(filename, cache_dir=None):
|
||||
|
||||
def cached_path(
|
||||
url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None
|
||||
):
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Given something that might be a URL (or might be a local path),
|
||||
determine which. If it's a URL, download the file and cache it, and
|
||||
@@ -193,6 +194,10 @@ def cached_path(
|
||||
force_download: if True, re-download the file even if it's already cached in the cache dir.
|
||||
resume_download: if True, resume the download if an incompletely received file is found.
|
||||
user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
|
||||
|
||||
Return:
|
||||
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
|
||||
Local path (string) otherwise
|
||||
"""
|
||||
if cache_dir is None:
|
||||
cache_dir = TRANSFORMERS_CACHE
|
||||
@@ -306,10 +311,14 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
|
||||
|
||||
def get_from_cache(
|
||||
url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None
|
||||
):
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Given a URL, look for the corresponding dataset in the local cache.
|
||||
Given a URL, look for the corresponding file in the local cache.
|
||||
If it's not there, download it. Then return the path to the cached file.
|
||||
|
||||
Return:
|
||||
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
|
||||
Local path (string) otherwise
|
||||
"""
|
||||
if cache_dir is None:
|
||||
cache_dir = TRANSFORMERS_CACHE
|
||||
@@ -336,16 +345,25 @@ def get_from_cache(
|
||||
# get cache path to put the file
|
||||
cache_path = os.path.join(cache_dir, filename)
|
||||
|
||||
# If we don't have a connection (etag is None) and can't identify the file
|
||||
# etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
|
||||
# try to get the last downloaded one
|
||||
if not os.path.exists(cache_path) and etag is None:
|
||||
matching_files = [
|
||||
file
|
||||
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
|
||||
if not file.endswith(".json") and not file.endswith(".lock")
|
||||
]
|
||||
if matching_files:
|
||||
cache_path = os.path.join(cache_dir, matching_files[-1])
|
||||
if etag is None:
|
||||
if os.path.exists(cache_path):
|
||||
return cache_path
|
||||
else:
|
||||
matching_files = [
|
||||
file
|
||||
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
|
||||
if not file.endswith(".json") and not file.endswith(".lock")
|
||||
]
|
||||
if len(matching_files) > 0:
|
||||
return os.path.join(cache_dir, matching_files[-1])
|
||||
else:
|
||||
return None
|
||||
|
||||
# From now on, etag is not None.
|
||||
if os.path.exists(cache_path) and not force_download:
|
||||
return cache_path
|
||||
|
||||
# Prevent parallel downloads of the same file with a lock.
|
||||
lock_path = cache_path + ".lock"
|
||||
@@ -368,29 +386,26 @@ def get_from_cache(
|
||||
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
|
||||
resume_size = 0
|
||||
|
||||
if etag is not None and (not os.path.exists(cache_path) or force_download):
|
||||
# Download to temporary file, then copy to cache dir once finished.
|
||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
||||
with temp_file_manager() as temp_file:
|
||||
logger.info(
|
||||
"%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name
|
||||
)
|
||||
# Download to temporary file, then copy to cache dir once finished.
|
||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
||||
with temp_file_manager() as temp_file:
|
||||
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
||||
|
||||
# GET file object
|
||||
if url.startswith("s3://"):
|
||||
if resume_download:
|
||||
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
|
||||
s3_get(url, temp_file, proxies=proxies)
|
||||
else:
|
||||
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
|
||||
# GET file object
|
||||
if url.startswith("s3://"):
|
||||
if resume_download:
|
||||
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
|
||||
s3_get(url, temp_file, proxies=proxies)
|
||||
else:
|
||||
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
|
||||
|
||||
logger.info("storing %s in cache at %s", url, cache_path)
|
||||
os.rename(temp_file.name, cache_path)
|
||||
logger.info("storing %s in cache at %s", url, cache_path)
|
||||
os.rename(temp_file.name, cache_path)
|
||||
|
||||
logger.info("creating metadata file for %s", cache_path)
|
||||
meta = {"url": url, "etag": etag}
|
||||
meta_path = cache_path + ".json"
|
||||
with open(meta_path, "w") as meta_file:
|
||||
json.dump(meta, meta_file)
|
||||
logger.info("creating metadata file for %s", cache_path)
|
||||
meta = {"url": url, "etag": etag}
|
||||
meta_path = cache_path + ".json"
|
||||
with open(meta_path, "w") as meta_file:
|
||||
json.dump(meta, meta_file)
|
||||
|
||||
return cache_path
|
||||
|
||||
Reference in New Issue
Block a user