Model versioning (#8324)

* fix typo

* rm use_cdn & references, and implement new hf_bucket_url

* I'm pretty sure we don't need to `read` this file

* same here

* [BIG] file_utils.networking: do not gobble up errors anymore

* Fix CI 😇

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Tiny doc tweak

* Add doc + pass kwarg everywhere

* Add more tests and explain

cc @sshleifer let me know if better

Co-Authored-By: Sam Shleifer <sshleifer@gmail.com>

* Also implement revision in pipelines

In the case where we're passing a task name or a string model identifier

* Fix CI 😇

* Fix CI

* [hf_api] new methods + command line implem

* make style

* Final endpoints post-migration

* Fix post-migration

* Py3.6 compat

cc @stefan-it

Thank you @stas00

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Julien Chaumond
2020-11-10 13:11:02 +01:00
committed by GitHub
parent 4185b115d4
commit 70f622fab4
23 changed files with 472 additions and 210 deletions

View File

@@ -4,6 +4,7 @@ https://github.com/allenai/allennlp Copyright by the AllenNLP authors.
"""
import fnmatch
import io
import json
import os
import re
@@ -17,7 +18,7 @@ from dataclasses import fields
from functools import partial, wraps
from hashlib import sha256
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, BinaryIO, Dict, Optional, Tuple, Union
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile
@@ -217,6 +218,8 @@ DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
PRESET_MIRROR_DICT = {
"tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
"bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
@@ -825,34 +828,37 @@ def is_remote_url(url_or_filename):
return parsed.scheme in ("http", "https")
def hf_bucket_url(model_id: str, filename: str, use_cdn=True, mirror=None) -> str:
def hf_bucket_url(model_id: str, filename: str, revision: Optional[str] = None, mirror=None) -> str:
"""
Resolve a model identifier, and a file name, to a HF-hosted url on either S3 or Cloudfront (a Content Delivery
Network, or CDN).
Resolve a model identifier, a file name, and an optional revision id, to a huggingface.co-hosted url, redirecting
to Cloudfront (a Content Delivery Network, or CDN) for large files.
Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our
bandwidth costs). However, it is more aggressively cached by default, so may not always reflect the latest changes
to the underlying file (default TTL is 24 hours).
bandwidth costs).
In terms of client-side caching from this library, even though Cloudfront relays the ETags from S3, using one or
the other (or switching from one to the other) will affect caching: cached files are not shared between the two
because the cached file's name contains a hash of the url.
Cloudfront aggressively caches files by default (default TTL is 24 hours), however this is not an issue here
because we migrated to a git-based versioning system on huggingface.co, so we now store the files on S3/Cloudfront
in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means cache
can't ever be stale.
In terms of client-side caching from this library, we base our caching on the objects' ETag. An object's ETag is:
its sha1 if stored in git, or its sha256 if stored in git-lfs. Files cached locally from transformers before v3.5.0
are not shared with those new files, because the cached file's name contains a hash of the url (which changed).
"""
endpoint = (
PRESET_MIRROR_DICT.get(mirror, mirror)
if mirror
else CLOUDFRONT_DISTRIB_PREFIX
if use_cdn
else S3_BUCKET_PREFIX
)
legacy_format = "/" not in model_id
if legacy_format:
return f"{endpoint}/{model_id}-{filename}"
else:
return f"{endpoint}/{model_id}/{filename}"
if mirror:
endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
legacy_format = "/" not in model_id
if legacy_format:
return f"{endpoint}/{model_id}-{filename}"
else:
return f"{endpoint}/{model_id}/{filename}"
if revision is None:
revision = "main"
return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)
def url_to_filename(url, etag=None):
def url_to_filename(url: str, etag: Optional[str] = None) -> str:
"""
Convert `url` into a hashed filename in a repeatable way. If `etag` is specified, append its hash to the url's,
delimited by a period. If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name so that TF 2.0 can
@@ -860,13 +866,11 @@ def url_to_filename(url, etag=None):
https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
"""
url_bytes = url.encode("utf-8")
url_hash = sha256(url_bytes)
filename = url_hash.hexdigest()
filename = sha256(url_bytes).hexdigest()
if etag:
etag_bytes = etag.encode("utf-8")
etag_hash = sha256(etag_bytes)
filename += "." + etag_hash.hexdigest()
filename += "." + sha256(etag_bytes).hexdigest()
if url.endswith(".h5"):
filename += ".h5"
@@ -927,8 +931,10 @@ def cached_path(
re-extract the archive and override the folder where it was extracted.
Return:
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). Local path (string)
otherwise
Local path (string) of file or, if networking is off, last version of file cached on disk.
Raises:
In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
"""
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
@@ -992,7 +998,10 @@ def cached_path(
return output_path
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
"""
Formats a user-agent string with basic info about a request.
"""
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
if is_torch_available():
ua += "; torch/{}".format(torch.__version__)
@@ -1002,13 +1011,19 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict
ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
elif isinstance(user_agent, str):
ua += "; " + user_agent
headers = {"user-agent": ua}
return ua
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
"""
Download remote file. Do not gobble up errors.
"""
headers = {"user-agent": http_user_agent(user_agent)}
if resume_size > 0:
headers["Range"] = "bytes=%d-" % (resume_size,)
response = requests.get(url, stream=True, proxies=proxies, headers=headers)
if response.status_code == 416: # Range not satisfiable
return
content_length = response.headers.get("Content-Length")
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
r.raise_for_status()
content_length = r.headers.get("Content-Length")
total = resume_size + int(content_length) if content_length is not None else None
progress = tqdm(
unit="B",
@@ -1018,7 +1033,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict
desc="Downloading",
disable=bool(logging.get_verbosity() == logging.NOTSET),
)
for chunk in response.iter_content(chunk_size=1024):
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
@@ -1026,7 +1041,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict
def get_from_cache(
url,
url: str,
cache_dir=None,
force_download=False,
proxies=None,
@@ -1040,8 +1055,10 @@ def get_from_cache(
path to the cached file.
Return:
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). Local path (string)
otherwise
Local path (string) of file or, if networking is off, last version of file cached on disk.
Raises:
In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
"""
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
@@ -1050,13 +1067,28 @@ def get_from_cache(
os.makedirs(cache_dir, exist_ok=True)
url_to_download = url
etag = None
if not local_files_only:
try:
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
if response.status_code == 200:
etag = response.headers.get("ETag")
except (EnvironmentError, requests.exceptions.Timeout):
headers = {"user-agent": http_user_agent(user_agent)}
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
r.raise_for_status()
etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
# We favor a custom header indicating the etag of the linked resource, and
# we fallback to the regular etag header.
# If we don't have any of those, raise an error.
if etag is None:
raise OSError(
"Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
)
# In case of a redirect,
# save an extra redirect on the request.get call,
# and ensure we download the exact atomic version even if it changed
# between the HEAD and the GET (unlikely, but hey).
if 300 <= r.status_code <= 399:
url_to_download = r.headers["Location"]
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
# etag is already None
pass
@@ -1065,7 +1097,7 @@ def get_from_cache(
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
# etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
# etag is None == we don't have a connection or we passed local_files_only.
# try to get the last downloaded one
if etag is None:
if os.path.exists(cache_path):
@@ -1088,7 +1120,11 @@ def get_from_cache(
" disabled. To enable model look-ups and downloads online, set 'local_files_only'"
" to False."
)
return None
else:
raise ValueError(
"Connection error, and we cannot find the requested files in the cached path."
" Please try again or make sure your Internet connection is on."
)
# From now on, etag is not None.
if os.path.exists(cache_path) and not force_download:
@@ -1107,8 +1143,8 @@ def get_from_cache(
incomplete_path = cache_path + ".incomplete"
@contextmanager
def _resumable_file_manager():
with open(incomplete_path, "a+b") as f:
def _resumable_file_manager() -> "io.BufferedWriter":
with open(incomplete_path, "ab") as f:
yield f
temp_file_manager = _resumable_file_manager
@@ -1117,7 +1153,7 @@ def get_from_cache(
else:
resume_size = 0
else:
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
resume_size = 0
# Download to temporary file, then copy to cache dir once finished.
@@ -1125,7 +1161,7 @@ def get_from_cache(
with temp_file_manager() as temp_file:
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
logger.info("storing %s in cache at %s", url, cache_path)
os.replace(temp_file.name, cache_path)