Harmonize HF environment variables + other cleaning (#27564)
* Harmonize HF environment variables + other cleaning * backward compat * switch from HUGGINGFACE_HUB_CACHE to HF_HUB_CACHE * revert
This commit is contained in:
@@ -25,6 +25,8 @@ import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from huggingface_hub import try_to_load_from_cache
|
||||
|
||||
from .utils import (
|
||||
HF_MODULES_CACHE,
|
||||
TRANSFORMERS_DYNAMIC_MODULE_NAME,
|
||||
@@ -32,7 +34,6 @@ from .utils import (
|
||||
extract_commit_hash,
|
||||
is_offline_mode,
|
||||
logging,
|
||||
try_to_load_from_cache,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -31,13 +31,16 @@ from uuid import uuid4
|
||||
import huggingface_hub
|
||||
import requests
|
||||
from huggingface_hub import (
|
||||
_CACHED_NO_EXIST,
|
||||
CommitOperationAdd,
|
||||
constants,
|
||||
create_branch,
|
||||
create_commit,
|
||||
create_repo,
|
||||
get_hf_file_metadata,
|
||||
hf_hub_download,
|
||||
hf_hub_url,
|
||||
try_to_load_from_cache,
|
||||
)
|
||||
from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get
|
||||
from huggingface_hub.utils import (
|
||||
@@ -49,7 +52,9 @@ from huggingface_hub.utils import (
|
||||
RevisionNotFoundError,
|
||||
build_hf_headers,
|
||||
hf_raise_for_status,
|
||||
send_telemetry,
|
||||
)
|
||||
from huggingface_hub.utils._deprecation import _deprecate_method
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from . import __version__, logging
|
||||
@@ -75,17 +80,25 @@ def is_offline_mode():
|
||||
|
||||
|
||||
torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
|
||||
default_cache_path = constants.default_cache_path
|
||||
old_default_cache_path = os.path.join(torch_cache_home, "transformers")
|
||||
# New default cache, shared with the Datasets library
|
||||
hf_cache_home = os.path.expanduser(
|
||||
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
|
||||
)
|
||||
default_cache_path = os.path.join(hf_cache_home, "hub")
|
||||
|
||||
# Determine default cache directory. Lots of legacy environment variables to ensure backward compatibility.
|
||||
# The best way to set the cache path is with the environment variable HF_HOME. For more details, checkout this
|
||||
# documentation page: https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables.
|
||||
#
|
||||
# In code, use `HF_HUB_CACHE` as the default cache path. This variable is set by the library and is guaranteed
|
||||
# to be set to the right value.
|
||||
#
|
||||
# TODO: clean this for v5?
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", constants.HF_HUB_CACHE)
|
||||
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
|
||||
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
|
||||
|
||||
# Onetime move from the old location to the new one if no ENV variable has been set.
|
||||
if (
|
||||
os.path.isdir(old_default_cache_path)
|
||||
and not os.path.isdir(default_cache_path)
|
||||
and not os.path.isdir(constants.HF_HUB_CACHE)
|
||||
and "PYTORCH_PRETRAINED_BERT_CACHE" not in os.environ
|
||||
and "PYTORCH_TRANSFORMERS_CACHE" not in os.environ
|
||||
and "TRANSFORMERS_CACHE" not in os.environ
|
||||
@@ -97,16 +110,26 @@ if (
|
||||
" '~/.cache/huggingface/transformers' to avoid redownloading models you have already in the cache. You should"
|
||||
" only see this message once."
|
||||
)
|
||||
shutil.move(old_default_cache_path, default_cache_path)
|
||||
shutil.move(old_default_cache_path, constants.HF_HUB_CACHE)
|
||||
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
|
||||
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
|
||||
HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", PYTORCH_TRANSFORMERS_CACHE)
|
||||
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", HUGGINGFACE_HUB_CACHE)
|
||||
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
|
||||
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(constants.HF_HOME, "modules"))
|
||||
TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules"
|
||||
SESSION_ID = uuid4().hex
|
||||
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) in ENV_VARS_TRUE_VALUES
|
||||
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", constants.HF_HUB_DISABLE_TELEMETRY) in ENV_VARS_TRUE_VALUES
|
||||
|
||||
# Add deprecation warning for old environment variables.
|
||||
for key in ("PYTORCH_PRETRAINED_BERT_CACHE", "PYTORCH_TRANSFORMERS_CACHE", "TRANSFORMERS_CACHE"):
|
||||
if os.getenv(key) is not None:
|
||||
warnings.warn(
|
||||
f"Using `{key}` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
if os.getenv("DISABLE_TELEMETRY") is not None:
|
||||
warnings.warn(
|
||||
"Using `DISABLE_TELEMETRY` is deprecated and will be removed in v5 of Transformers. Use `HF_HUB_DISABLE_TELEMETRY` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
|
||||
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
|
||||
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
|
||||
@@ -126,15 +149,16 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", HUGGINGFACE_CO_R
|
||||
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
|
||||
HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"
|
||||
|
||||
# Return value when trying to load a file from cache but the file does not exist in the distant repo.
|
||||
_CACHED_NO_EXIST = object()
|
||||
|
||||
|
||||
def is_remote_url(url_or_filename):
|
||||
parsed = urlparse(url_or_filename)
|
||||
return parsed.scheme in ("http", "https")
|
||||
|
||||
|
||||
# TODO: remove this once fully deprecated
|
||||
# TODO? remove from './examples/research_projects/lxmert/utils.py' as well
|
||||
# TODO? remove from './examples/research_projects/visual_bert/utils.py' as well
|
||||
@_deprecate_method(version="4.39.0", message="This method is outdated and does not support the new cache system.")
|
||||
def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
|
||||
"""
|
||||
Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape `(model_url,
|
||||
@@ -219,7 +243,7 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
||||
return ua
|
||||
|
||||
|
||||
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]):
|
||||
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
Extracts the commit hash from a resolved filename toward a cache file.
|
||||
"""
|
||||
@@ -233,73 +257,6 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]
|
||||
return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
|
||||
|
||||
|
||||
def try_to_load_from_cache(
|
||||
repo_id: str,
|
||||
filename: str,
|
||||
cache_dir: Union[str, Path, None] = None,
|
||||
revision: Optional[str] = None,
|
||||
repo_type: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Explores the cache to return the latest cached file for a given revision if found.
|
||||
|
||||
This function will not raise any exception if the file in not cached.
|
||||
|
||||
Args:
|
||||
cache_dir (`str` or `os.PathLike`):
|
||||
The folder where the cached files lie.
|
||||
repo_id (`str`):
|
||||
The ID of the repo on huggingface.co.
|
||||
filename (`str`):
|
||||
The filename to look for inside `repo_id`.
|
||||
revision (`str`, *optional*):
|
||||
The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
|
||||
provided either.
|
||||
repo_type (`str`, *optional*):
|
||||
The type of the repo.
|
||||
|
||||
Returns:
|
||||
`Optional[str]` or `_CACHED_NO_EXIST`:
|
||||
Will return `None` if the file was not cached. Otherwise:
|
||||
- The exact path to the cached file if it's found in the cache
|
||||
- A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
|
||||
cached.
|
||||
"""
|
||||
if revision is None:
|
||||
revision = "main"
|
||||
|
||||
if cache_dir is None:
|
||||
cache_dir = TRANSFORMERS_CACHE
|
||||
|
||||
object_id = repo_id.replace("/", "--")
|
||||
if repo_type is None:
|
||||
repo_type = "model"
|
||||
repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
|
||||
if not os.path.isdir(repo_cache):
|
||||
# No cache for this model
|
||||
return None
|
||||
for subfolder in ["refs", "snapshots"]:
|
||||
if not os.path.isdir(os.path.join(repo_cache, subfolder)):
|
||||
return None
|
||||
|
||||
# Resolve refs (for instance to convert main to the associated commit sha)
|
||||
cached_refs = os.listdir(os.path.join(repo_cache, "refs"))
|
||||
if revision in cached_refs:
|
||||
with open(os.path.join(repo_cache, "refs", revision)) as f:
|
||||
revision = f.read()
|
||||
|
||||
if os.path.isfile(os.path.join(repo_cache, ".no_exist", revision, filename)):
|
||||
return _CACHED_NO_EXIST
|
||||
|
||||
cached_shas = os.listdir(os.path.join(repo_cache, "snapshots"))
|
||||
if revision not in cached_shas:
|
||||
# No cache for this revision and we won't try to return a random revision
|
||||
return None
|
||||
|
||||
cached_file = os.path.join(repo_cache, "snapshots", revision, filename)
|
||||
return cached_file if os.path.isfile(cached_file) else None
|
||||
|
||||
|
||||
def cached_file(
|
||||
path_or_repo_id: Union[str, os.PathLike],
|
||||
filename: str,
|
||||
@@ -317,7 +274,7 @@ def cached_file(
|
||||
_raise_exceptions_for_connection_errors: bool = True,
|
||||
_commit_hash: Optional[str] = None,
|
||||
**deprecated_kwargs,
|
||||
):
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
|
||||
|
||||
@@ -369,7 +326,8 @@ def cached_file(
|
||||
```python
|
||||
# Download a model weight from the Hub and cache it.
|
||||
model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")
|
||||
```"""
|
||||
```
|
||||
"""
|
||||
use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
@@ -499,6 +457,10 @@ def cached_file(
|
||||
return resolved_file
|
||||
|
||||
|
||||
# TODO: deprecate `get_file_from_repo` or document it differently?
|
||||
# Docstring is exactly the same as `cached_repo` but behavior is slightly different. If file is missing or if
|
||||
# there is a connection error, `cached_repo` will return None while `get_file_from_repo` will raise an error.
|
||||
# IMO we should keep only 1 method and have a single `raise_error` argument (to be discussed).
|
||||
def get_file_from_repo(
|
||||
path_or_repo: Union[str, os.PathLike],
|
||||
filename: str,
|
||||
@@ -564,7 +526,8 @@ def get_file_from_repo(
|
||||
tokenizer_config = get_file_from_repo("bert-base-uncased", "tokenizer_config.json")
|
||||
# This model does not have a tokenizer config so the result will be None.
|
||||
tokenizer_config = get_file_from_repo("xlm-roberta-base", "tokenizer_config.json")
|
||||
```"""
|
||||
```
|
||||
"""
|
||||
use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
@@ -609,10 +572,11 @@ def download_url(url, proxies=None):
|
||||
f"Using `from_pretrained` with the url of a file (here {url}) is deprecated and won't be possible anymore in"
|
||||
" v5 of Transformers. You should host your file on the Hub (hf.co) instead and use the repository ID. Note"
|
||||
" that this is not compatible with the caching system (your file will be downloaded at each execution) or"
|
||||
" multiple processes (each process will download the file in a different temporary file)."
|
||||
" multiple processes (each process will download the file in a different temporary file).",
|
||||
FutureWarning,
|
||||
)
|
||||
tmp_file = tempfile.mkstemp()[1]
|
||||
with open(tmp_file, "wb") as f:
|
||||
tmp_fd, tmp_file = tempfile.mkstemp()
|
||||
with os.fdopen(tmp_fd, "wb") as f:
|
||||
http_get(url, f, proxies=proxies)
|
||||
return tmp_file
|
||||
|
||||
@@ -947,13 +911,10 @@ def send_example_telemetry(example_name, *example_args, framework="pytorch"):
|
||||
script_name = script_name.replace("_no_trainer", "")
|
||||
data["dataset_name"] = f"{script_name}-{args_as_dict['task_name']}"
|
||||
|
||||
headers = {"user-agent": http_user_agent(data)}
|
||||
try:
|
||||
r = requests.head(HUGGINGFACE_CO_EXAMPLES_TELEMETRY, headers=headers)
|
||||
r.raise_for_status()
|
||||
except Exception:
|
||||
# We don't want to error in case of connection errors of any kind.
|
||||
pass
|
||||
# Send telemetry in the background
|
||||
send_telemetry(
|
||||
topic="examples", library_name="transformers", library_version=__version__, user_agent=http_user_agent(data)
|
||||
)
|
||||
|
||||
|
||||
def convert_file_size_to_int(size: Union[int, str]):
|
||||
@@ -1258,7 +1219,7 @@ if cache_version < 1 and cache_is_not_empty:
|
||||
"`transformers.utils.move_cache()`."
|
||||
)
|
||||
try:
|
||||
if TRANSFORMERS_CACHE != default_cache_path:
|
||||
if TRANSFORMERS_CACHE != constants.HF_HUB_CACHE:
|
||||
# Users set some env variable to customize cache storage
|
||||
move_cache(TRANSFORMERS_CACHE, TRANSFORMERS_CACHE)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user