Support for private models from huggingface.co (#9141)
* minor wording tweaks * Create private model repo + exist_ok flag * file_utils: `use_auth_token` * Update src/transformers/file_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Propagate doc from @sgugger Co-Authored-By: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -16,6 +16,7 @@ Utilities for working with the local dataset cache. Parts of this file is adapte
|
||||
https://github.com/allenai/allennlp.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import fnmatch
|
||||
import io
|
||||
import json
|
||||
@@ -42,6 +43,7 @@ import requests
|
||||
from filelock import FileLock
|
||||
|
||||
from . import __version__
|
||||
from .hf_api import HfFolder
|
||||
from .utils import logging
|
||||
|
||||
|
||||
@@ -1024,6 +1026,7 @@ def cached_path(
|
||||
user_agent: Union[Dict, str, None] = None,
|
||||
extract_compressed_file=False,
|
||||
force_extract=False,
|
||||
use_auth_token: Union[bool, str, None] = None,
|
||||
local_files_only=False,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
@@ -1036,6 +1039,8 @@ def cached_path(
|
||||
force_download: if True, re-download the file even if it's already cached in the cache dir.
|
||||
resume_download: if True, resume the download if incompletely received file is found.
|
||||
user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
|
||||
use_auth_token: Optional string or boolean to use as Bearer token for remote files. If True,
|
||||
will get token from ~/.huggingface.
|
||||
extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed
|
||||
file in a folder along the archive.
|
||||
force_extract: if True when extract_compressed_file is True and the archive was already extracted,
|
||||
@@ -1063,6 +1068,7 @@ def cached_path(
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
user_agent=user_agent,
|
||||
use_auth_token=use_auth_token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
elif os.path.exists(url_or_filename):
|
||||
@@ -1125,11 +1131,11 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
||||
return ua
|
||||
|
||||
|
||||
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
|
||||
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
|
||||
"""
|
||||
Donwload remote file. Do not gobble up errors.
|
||||
"""
|
||||
headers = {"user-agent": http_user_agent(user_agent)}
|
||||
headers = copy.deepcopy(headers)
|
||||
if resume_size > 0:
|
||||
headers["Range"] = "bytes=%d-" % (resume_size,)
|
||||
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
||||
@@ -1159,6 +1165,7 @@ def get_from_cache(
|
||||
etag_timeout=10,
|
||||
resume_download=False,
|
||||
user_agent: Union[Dict, str, None] = None,
|
||||
use_auth_token: Union[bool, str, None] = None,
|
||||
local_files_only=False,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
@@ -1178,11 +1185,19 @@ def get_from_cache(
|
||||
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
headers = {"user-agent": http_user_agent(user_agent)}
|
||||
if isinstance(use_auth_token, str):
|
||||
headers["authorization"] = "Bearer {}".format(use_auth_token)
|
||||
elif use_auth_token:
|
||||
token = HfFolder.get_token()
|
||||
if token is None:
|
||||
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
|
||||
headers["authorization"] = "Bearer {}".format(token)
|
||||
|
||||
url_to_download = url
|
||||
etag = None
|
||||
if not local_files_only:
|
||||
try:
|
||||
headers = {"user-agent": http_user_agent(user_agent)}
|
||||
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
|
||||
r.raise_for_status()
|
||||
etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
|
||||
@@ -1272,7 +1287,7 @@ def get_from_cache(
|
||||
with temp_file_manager() as temp_file:
|
||||
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
||||
|
||||
http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
|
||||
http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)
|
||||
|
||||
logger.info("storing %s in cache at %s", url, cache_path)
|
||||
os.replace(temp_file.name, cache_path)
|
||||
|
||||
Reference in New Issue
Block a user