Model versioning (#8324)
* fix typo * rm use_cdn & references, and implement new hf_bucket_url * I'm pretty sure we don't need to `read` this file * same here * [BIG] file_utils.networking: do not gobble up errors anymore * Fix CI 😇 * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Tiny doc tweak * Add doc + pass kwarg everywhere * Add more tests and explain cc @sshleifer let me know if better Co-Authored-By: Sam Shleifer <sshleifer@gmail.com> * Also implement revision in pipelines In the case where we're passing a task name or a string model identifier * Fix CI 😇 * Fix CI * [hf_api] new methods + command line implem * make style * Final endpoints post-migration * Fix post-migration * Py3.6 compat cc @stefan-it Thank you @stas00 Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -4,6 +4,7 @@ https://github.com/allenai/allennlp Copyright by the AllenNLP authors.
|
||||
"""
|
||||
|
||||
import fnmatch
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@@ -17,7 +18,7 @@ from dataclasses import fields
|
||||
from functools import partial, wraps
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, BinaryIO, Dict, Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
from zipfile import ZipFile, is_zipfile
|
||||
|
||||
@@ -217,6 +218,8 @@ DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
|
||||
|
||||
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
|
||||
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
|
||||
HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
|
||||
|
||||
PRESET_MIRROR_DICT = {
|
||||
"tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
|
||||
"bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
|
||||
@@ -825,34 +828,37 @@ def is_remote_url(url_or_filename):
|
||||
return parsed.scheme in ("http", "https")
|
||||
|
||||
|
||||
def hf_bucket_url(model_id: str, filename: str, use_cdn=True, mirror=None) -> str:
|
||||
def hf_bucket_url(model_id: str, filename: str, revision: Optional[str] = None, mirror=None) -> str:
|
||||
"""
|
||||
Resolve a model identifier, and a file name, to a HF-hosted url on either S3 or Cloudfront (a Content Delivery
|
||||
Network, or CDN).
|
||||
Resolve a model identifier, a file name, and an optional revision id, to a huggingface.co-hosted url, redirecting
|
||||
to Cloudfront (a Content Delivery Network, or CDN) for large files.
|
||||
|
||||
Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our
|
||||
bandwidth costs). However, it is more aggressively cached by default, so may not always reflect the latest changes
|
||||
to the underlying file (default TTL is 24 hours).
|
||||
bandwidth costs).
|
||||
|
||||
In terms of client-side caching from this library, even though Cloudfront relays the ETags from S3, using one or
|
||||
the other (or switching from one to the other) will affect caching: cached files are not shared between the two
|
||||
because the cached file's name contains a hash of the url.
|
||||
Cloudfront aggressively caches files by default (default TTL is 24 hours), however this is not an issue here
|
||||
because we migrated to a git-based versioning system on huggingface.co, so we now store the files on S3/Cloudfront
|
||||
in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means cache
|
||||
can't ever be stale.
|
||||
|
||||
In terms of client-side caching from this library, we base our caching on the objects' ETag. An object' ETag is:
|
||||
its sha1 if stored in git, or its sha256 if stored in git-lfs. Files cached locally from transformers before v3.5.0
|
||||
are not shared with those new files, because the cached file's name contains a hash of the url (which changed).
|
||||
"""
|
||||
endpoint = (
|
||||
PRESET_MIRROR_DICT.get(mirror, mirror)
|
||||
if mirror
|
||||
else CLOUDFRONT_DISTRIB_PREFIX
|
||||
if use_cdn
|
||||
else S3_BUCKET_PREFIX
|
||||
)
|
||||
legacy_format = "/" not in model_id
|
||||
if legacy_format:
|
||||
return f"{endpoint}/{model_id}-{filename}"
|
||||
else:
|
||||
return f"{endpoint}/{model_id}/{filename}"
|
||||
if mirror:
|
||||
endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
|
||||
legacy_format = "/" not in model_id
|
||||
if legacy_format:
|
||||
return f"{endpoint}/{model_id}-{filename}"
|
||||
else:
|
||||
return f"{endpoint}/{model_id}/{filename}"
|
||||
|
||||
if revision is None:
|
||||
revision = "main"
|
||||
return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)
|
||||
|
||||
|
||||
def url_to_filename(url, etag=None):
|
||||
def url_to_filename(url: str, etag: Optional[str] = None) -> str:
|
||||
"""
|
||||
Convert `url` into a hashed filename in a repeatable way. If `etag` is specified, append its hash to the url's,
|
||||
delimited by a period. If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name so that TF 2.0 can
|
||||
@@ -860,13 +866,11 @@ def url_to_filename(url, etag=None):
|
||||
https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
|
||||
"""
|
||||
url_bytes = url.encode("utf-8")
|
||||
url_hash = sha256(url_bytes)
|
||||
filename = url_hash.hexdigest()
|
||||
filename = sha256(url_bytes).hexdigest()
|
||||
|
||||
if etag:
|
||||
etag_bytes = etag.encode("utf-8")
|
||||
etag_hash = sha256(etag_bytes)
|
||||
filename += "." + etag_hash.hexdigest()
|
||||
filename += "." + sha256(etag_bytes).hexdigest()
|
||||
|
||||
if url.endswith(".h5"):
|
||||
filename += ".h5"
|
||||
@@ -927,8 +931,10 @@ def cached_path(
|
||||
re-extract the archive and override the folder where it was extracted.
|
||||
|
||||
Return:
|
||||
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). Local path (string)
|
||||
otherwise
|
||||
Local path (string) of file or if networking is off, last version of file cached on disk.
|
||||
|
||||
Raises:
|
||||
In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
|
||||
"""
|
||||
if cache_dir is None:
|
||||
cache_dir = TRANSFORMERS_CACHE
|
||||
@@ -992,7 +998,10 @@ def cached_path(
|
||||
return output_path
|
||||
|
||||
|
||||
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
|
||||
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
||||
"""
|
||||
Formats a user-agent string with basic info about a request.
|
||||
"""
|
||||
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
|
||||
if is_torch_available():
|
||||
ua += "; torch/{}".format(torch.__version__)
|
||||
@@ -1002,13 +1011,19 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict
|
||||
ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
|
||||
elif isinstance(user_agent, str):
|
||||
ua += "; " + user_agent
|
||||
headers = {"user-agent": ua}
|
||||
return ua
|
||||
|
||||
|
||||
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
|
||||
"""
|
||||
Donwload remote file. Do not gobble up errors.
|
||||
"""
|
||||
headers = {"user-agent": http_user_agent(user_agent)}
|
||||
if resume_size > 0:
|
||||
headers["Range"] = "bytes=%d-" % (resume_size,)
|
||||
response = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
||||
if response.status_code == 416: # Range not satisfiable
|
||||
return
|
||||
content_length = response.headers.get("Content-Length")
|
||||
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
||||
r.raise_for_status()
|
||||
content_length = r.headers.get("Content-Length")
|
||||
total = resume_size + int(content_length) if content_length is not None else None
|
||||
progress = tqdm(
|
||||
unit="B",
|
||||
@@ -1018,7 +1033,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict
|
||||
desc="Downloading",
|
||||
disable=bool(logging.get_verbosity() == logging.NOTSET),
|
||||
)
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
for chunk in r.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
progress.update(len(chunk))
|
||||
temp_file.write(chunk)
|
||||
@@ -1026,7 +1041,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict
|
||||
|
||||
|
||||
def get_from_cache(
|
||||
url,
|
||||
url: str,
|
||||
cache_dir=None,
|
||||
force_download=False,
|
||||
proxies=None,
|
||||
@@ -1040,8 +1055,10 @@ def get_from_cache(
|
||||
path to the cached file.
|
||||
|
||||
Return:
|
||||
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). Local path (string)
|
||||
otherwise
|
||||
Local path (string) of file or if networking is off, last version of file cached on disk.
|
||||
|
||||
Raises:
|
||||
In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
|
||||
"""
|
||||
if cache_dir is None:
|
||||
cache_dir = TRANSFORMERS_CACHE
|
||||
@@ -1050,13 +1067,28 @@ def get_from_cache(
|
||||
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
url_to_download = url
|
||||
etag = None
|
||||
if not local_files_only:
|
||||
try:
|
||||
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
|
||||
if response.status_code == 200:
|
||||
etag = response.headers.get("ETag")
|
||||
except (EnvironmentError, requests.exceptions.Timeout):
|
||||
headers = {"user-agent": http_user_agent(user_agent)}
|
||||
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
|
||||
r.raise_for_status()
|
||||
etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
|
||||
# We favor a custom header indicating the etag of the linked resource, and
|
||||
# we fallback to the regular etag header.
|
||||
# If we don't have any of those, raise an error.
|
||||
if etag is None:
|
||||
raise OSError(
|
||||
"Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
|
||||
)
|
||||
# In case of a redirect,
|
||||
# save an extra redirect on the request.get call,
|
||||
# and ensure we download the exact atomic version even if it changed
|
||||
# between the HEAD and the GET (unlikely, but hey).
|
||||
if 300 <= r.status_code <= 399:
|
||||
url_to_download = r.headers["Location"]
|
||||
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
|
||||
# etag is already None
|
||||
pass
|
||||
|
||||
@@ -1065,7 +1097,7 @@ def get_from_cache(
|
||||
# get cache path to put the file
|
||||
cache_path = os.path.join(cache_dir, filename)
|
||||
|
||||
# etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
|
||||
# etag is None == we don't have a connection or we passed local_files_only.
|
||||
# try to get the last downloaded one
|
||||
if etag is None:
|
||||
if os.path.exists(cache_path):
|
||||
@@ -1088,7 +1120,11 @@ def get_from_cache(
|
||||
" disabled. To enable model look-ups and downloads online, set 'local_files_only'"
|
||||
" to False."
|
||||
)
|
||||
return None
|
||||
else:
|
||||
raise ValueError(
|
||||
"Connection error, and we cannot find the requested files in the cached path."
|
||||
" Please try again or make sure your Internet connection is on."
|
||||
)
|
||||
|
||||
# From now on, etag is not None.
|
||||
if os.path.exists(cache_path) and not force_download:
|
||||
@@ -1107,8 +1143,8 @@ def get_from_cache(
|
||||
incomplete_path = cache_path + ".incomplete"
|
||||
|
||||
@contextmanager
|
||||
def _resumable_file_manager():
|
||||
with open(incomplete_path, "a+b") as f:
|
||||
def _resumable_file_manager() -> "io.BufferedWriter":
|
||||
with open(incomplete_path, "ab") as f:
|
||||
yield f
|
||||
|
||||
temp_file_manager = _resumable_file_manager
|
||||
@@ -1117,7 +1153,7 @@ def get_from_cache(
|
||||
else:
|
||||
resume_size = 0
|
||||
else:
|
||||
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
|
||||
temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
|
||||
resume_size = 0
|
||||
|
||||
# Download to temporary file, then copy to cache dir once finished.
|
||||
@@ -1125,7 +1161,7 @@ def get_from_cache(
|
||||
with temp_file_manager() as temp_file:
|
||||
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
||||
|
||||
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
|
||||
http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
|
||||
|
||||
logger.info("storing %s in cache at %s", url, cache_path)
|
||||
os.replace(temp_file.name, cache_path)
|
||||
|
||||
Reference in New Issue
Block a user