Clean up utils.hub using the latest from hf_hub (#18857)
* Clean up utils.hub using the latest from hf_hub * Adapt test * Address review comment * Fix test
This commit is contained in:
2
setup.py
2
setup.py
@@ -116,7 +116,7 @@ _deps = [
|
||||
"fugashi>=1.0",
|
||||
"GitPython<3.1.19",
|
||||
"hf-doc-builder>=0.3.0",
|
||||
"huggingface-hub>=0.8.1,<1.0",
|
||||
"huggingface-hub>=0.9.0,<1.0",
|
||||
"importlib_metadata",
|
||||
"ipadic>=1.0.0,<2.0",
|
||||
"isort>=5.5.4",
|
||||
|
||||
@@ -22,7 +22,7 @@ deps = {
|
||||
"fugashi": "fugashi>=1.0",
|
||||
"GitPython": "GitPython<3.1.19",
|
||||
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
||||
"huggingface-hub": "huggingface-hub>=0.8.1,<1.0",
|
||||
"huggingface-hub": "huggingface-hub>=0.9.0,<1.0",
|
||||
"importlib_metadata": "importlib_metadata",
|
||||
"ipadic": "ipadic>=1.0.0,<2.0",
|
||||
"isort": "isort>=5.5.4",
|
||||
|
||||
@@ -21,7 +21,6 @@ import shutil
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from uuid import uuid4
|
||||
@@ -39,7 +38,12 @@ from huggingface_hub import (
|
||||
)
|
||||
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
|
||||
from huggingface_hub.file_download import REGEX_COMMIT_HASH
|
||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||
from huggingface_hub.utils import (
|
||||
EntryNotFoundError,
|
||||
LocalEntryNotFoundError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
)
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers.utils.logging import tqdm
|
||||
|
||||
@@ -249,28 +253,6 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_h
|
||||
return cached_file if os.path.isfile(cached_file) else None
|
||||
|
||||
|
||||
# If huggingface_hub changes the class of error for this to FileNotFoundError, we will be able to avoid that in the
|
||||
# future.
|
||||
LOCAL_FILES_ONLY_HF_ERROR = (
|
||||
"Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable hf.co "
|
||||
"look-ups and downloads online, set 'local_files_only' to False."
|
||||
)
|
||||
|
||||
|
||||
# In the future, this ugly contextmanager can be removed when huggingface_hub has a released version where we can
|
||||
# activate/deactivate progress bars.
|
||||
@contextmanager
|
||||
def _patch_hf_hub_tqdm():
|
||||
"""
|
||||
A context manager to make huggingface hub use the tqdm version of Transformers (which is controlled by some utils)
|
||||
in logging.
|
||||
"""
|
||||
old_tqdm = huggingface_hub.file_download.tqdm
|
||||
huggingface_hub.file_download.tqdm = tqdm
|
||||
yield
|
||||
huggingface_hub.file_download.tqdm = old_tqdm
|
||||
|
||||
|
||||
def cached_file(
|
||||
path_or_repo_id: Union[str, os.PathLike],
|
||||
filename: str,
|
||||
@@ -375,20 +357,19 @@ def cached_file(
|
||||
user_agent = http_user_agent(user_agent)
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
with _patch_hf_hub_tqdm():
|
||||
resolved_file = hf_hub_download(
|
||||
path_or_repo_id,
|
||||
filename,
|
||||
subfolder=None if len(subfolder) == 0 else subfolder,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
user_agent=user_agent,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
use_auth_token=use_auth_token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
resolved_file = hf_hub_download(
|
||||
path_or_repo_id,
|
||||
filename,
|
||||
subfolder=None if len(subfolder) == 0 else subfolder,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
user_agent=user_agent,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
use_auth_token=use_auth_token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
@@ -403,6 +384,19 @@ def cached_file(
|
||||
"for this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
|
||||
)
|
||||
except LocalEntryNotFoundError:
|
||||
# We try to see if we have a cached version (not up to date):
|
||||
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
|
||||
if resolved_file is not None:
|
||||
return resolved_file
|
||||
if not _raise_exceptions_for_missing_entries or not _raise_exceptions_for_connection_errors:
|
||||
return None
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
|
||||
f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
|
||||
f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
|
||||
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
if not _raise_exceptions_for_missing_entries:
|
||||
return None
|
||||
@@ -421,24 +415,6 @@ def cached_file(
|
||||
return None
|
||||
|
||||
raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}")
|
||||
except ValueError as err:
|
||||
# HuggingFace Hub returns a ValueError for a missing file when local_files_only=True; we need to catch it here
|
||||
# This could be caught above along with `EntryNotFoundError` if hf_hub sent a different error message here
|
||||
if LOCAL_FILES_ONLY_HF_ERROR in err.args[0] and local_files_only and not _raise_exceptions_for_missing_entries:
|
||||
return None
|
||||
|
||||
# Otherwise we try to see if we have a cached version (not up to date):
|
||||
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
|
||||
if resolved_file is not None:
|
||||
return resolved_file
|
||||
if not _raise_exceptions_for_connection_errors:
|
||||
return None
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
|
||||
f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
|
||||
f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
|
||||
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||
)
|
||||
|
||||
return resolved_file
|
||||
|
||||
|
||||
@@ -30,6 +30,8 @@ from typing import Optional
|
||||
|
||||
from tqdm import auto as tqdm_lib
|
||||
|
||||
import huggingface_hub.utils as hf_hub_utils
|
||||
|
||||
|
||||
_lock = threading.Lock()
|
||||
_default_handler: Optional[logging.Handler] = None
|
||||
@@ -336,9 +338,11 @@ def enable_progress_bar():
|
||||
"""Enable tqdm progress bar."""
|
||||
global _tqdm_active
|
||||
_tqdm_active = True
|
||||
hf_hub_utils.enable_progress_bars()
|
||||
|
||||
|
||||
def disable_progress_bar():
|
||||
"""Disable tqdm progress bar."""
|
||||
global _tqdm_active
|
||||
_tqdm_active = False
|
||||
hf_hub_utils.disable_progress_bars()
|
||||
|
||||
@@ -14,10 +14,10 @@
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import transformers.models.bart.tokenization_bart
|
||||
from transformers import AutoConfig, logging
|
||||
from huggingface_hub.utils import are_progress_bars_disabled
|
||||
from transformers import logging
|
||||
from transformers.testing_utils import CaptureLogger, mockenv, mockenv_context
|
||||
from transformers.utils.logging import disable_progress_bar, enable_progress_bar
|
||||
|
||||
@@ -126,14 +126,8 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
|
||||
|
||||
def test_set_progress_bar_enabled():
|
||||
TINY_MODEL = "hf-internal-testing/tiny-random-distilbert"
|
||||
with patch("tqdm.auto.tqdm") as mock_tqdm:
|
||||
disable_progress_bar()
|
||||
_ = AutoConfig.from_pretrained(TINY_MODEL, force_download=True)
|
||||
mock_tqdm.assert_not_called()
|
||||
disable_progress_bar()
|
||||
assert are_progress_bars_disabled()
|
||||
|
||||
mock_tqdm.reset_mock()
|
||||
|
||||
enable_progress_bar()
|
||||
_ = AutoConfig.from_pretrained(TINY_MODEL, force_download=True)
|
||||
mock_tqdm.assert_called()
|
||||
enable_progress_bar()
|
||||
assert not are_progress_bars_disabled()
|
||||
|
||||
Reference in New Issue
Block a user