Support for private models from huggingface.co (#9141)

* minor wording tweaks

* Create private model repo + exist_ok flag

* file_utils: `use_auth_token`

* Update src/transformers/file_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Propagate doc from @sgugger

Co-Authored-By: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Julien Chaumond
2020-12-16 16:09:57 +01:00
committed by GitHub
parent c69d19faa8
commit fb650df859
8 changed files with 77 additions and 9 deletions

View File

@@ -16,6 +16,7 @@ Utilities for working with the local dataset cache. Parts of this file is adapte
https://github.com/allenai/allennlp.
"""
import copy
import fnmatch
import io
import json
@@ -42,6 +43,7 @@ import requests
from filelock import FileLock
from . import __version__
from .hf_api import HfFolder
from .utils import logging
@@ -1024,6 +1026,7 @@ def cached_path(
user_agent: Union[Dict, str, None] = None,
extract_compressed_file=False,
force_extract=False,
use_auth_token: Union[bool, str, None] = None,
local_files_only=False,
) -> Optional[str]:
"""
@@ -1036,6 +1039,8 @@ def cached_path(
force_download: if True, re-download the file even if it's already cached in the cache dir.
resume_download: if True, resume the download if incompletely received file is found.
user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
use_auth_token: Optional string or boolean to use as Bearer token for remote files. If True,
will get token from ~/.huggingface.
extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed
file in a folder along the archive.
force_extract: if True when extract_compressed_file is True and the archive was already extracted,
@@ -1063,6 +1068,7 @@ def cached_path(
proxies=proxies,
resume_download=resume_download,
user_agent=user_agent,
use_auth_token=use_auth_token,
local_files_only=local_files_only,
)
elif os.path.exists(url_or_filename):
@@ -1125,11 +1131,11 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
return ua
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
"""
Donwload remote file. Do not gobble up errors.
"""
headers = {"user-agent": http_user_agent(user_agent)}
headers = copy.deepcopy(headers)
if resume_size > 0:
headers["Range"] = "bytes=%d-" % (resume_size,)
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
@@ -1159,6 +1165,7 @@ def get_from_cache(
etag_timeout=10,
resume_download=False,
user_agent: Union[Dict, str, None] = None,
use_auth_token: Union[bool, str, None] = None,
local_files_only=False,
) -> Optional[str]:
"""
@@ -1178,11 +1185,19 @@ def get_from_cache(
os.makedirs(cache_dir, exist_ok=True)
headers = {"user-agent": http_user_agent(user_agent)}
if isinstance(use_auth_token, str):
headers["authorization"] = "Bearer {}".format(use_auth_token)
elif use_auth_token:
token = HfFolder.get_token()
if token is None:
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
headers["authorization"] = "Bearer {}".format(token)
url_to_download = url
etag = None
if not local_files_only:
try:
headers = {"user-agent": http_user_agent(user_agent)}
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
r.raise_for_status()
etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
@@ -1272,7 +1287,7 @@ def get_from_cache(
with temp_file_manager() as temp_file:
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)
logger.info("storing %s in cache at %s", url, cache_path)
os.replace(temp_file.name, cache_path)