Clean up utils.hub using the latest from hf_hub (#18857)
* Clean up utils.hub using the latest from hf_hub
* Adapt test
* Address review comment
* Fix test
This commit is contained in:
2
setup.py
2
setup.py
@@ -116,7 +116,7 @@ _deps = [
|
|||||||
"fugashi>=1.0",
|
"fugashi>=1.0",
|
||||||
"GitPython<3.1.19",
|
"GitPython<3.1.19",
|
||||||
"hf-doc-builder>=0.3.0",
|
"hf-doc-builder>=0.3.0",
|
||||||
"huggingface-hub>=0.8.1,<1.0",
|
"huggingface-hub>=0.9.0,<1.0",
|
||||||
"importlib_metadata",
|
"importlib_metadata",
|
||||||
"ipadic>=1.0.0,<2.0",
|
"ipadic>=1.0.0,<2.0",
|
||||||
"isort>=5.5.4",
|
"isort>=5.5.4",
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ deps = {
|
|||||||
"fugashi": "fugashi>=1.0",
|
"fugashi": "fugashi>=1.0",
|
||||||
"GitPython": "GitPython<3.1.19",
|
"GitPython": "GitPython<3.1.19",
|
||||||
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
||||||
"huggingface-hub": "huggingface-hub>=0.8.1,<1.0",
|
"huggingface-hub": "huggingface-hub>=0.9.0,<1.0",
|
||||||
"importlib_metadata": "importlib_metadata",
|
"importlib_metadata": "importlib_metadata",
|
||||||
"ipadic": "ipadic>=1.0.0,<2.0",
|
"ipadic": "ipadic>=1.0.0,<2.0",
|
||||||
"isort": "isort>=5.5.4",
|
"isort": "isort>=5.5.4",
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import shutil
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import contextmanager
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
@@ -39,7 +38,12 @@ from huggingface_hub import (
|
|||||||
)
|
)
|
||||||
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
|
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
|
||||||
from huggingface_hub.file_download import REGEX_COMMIT_HASH
|
from huggingface_hub.file_download import REGEX_COMMIT_HASH
|
||||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
from huggingface_hub.utils import (
|
||||||
|
EntryNotFoundError,
|
||||||
|
LocalEntryNotFoundError,
|
||||||
|
RepositoryNotFoundError,
|
||||||
|
RevisionNotFoundError,
|
||||||
|
)
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from transformers.utils.logging import tqdm
|
from transformers.utils.logging import tqdm
|
||||||
|
|
||||||
@@ -249,28 +253,6 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_h
|
|||||||
return cached_file if os.path.isfile(cached_file) else None
|
return cached_file if os.path.isfile(cached_file) else None
|
||||||
|
|
||||||
|
|
||||||
# If huggingface_hub changes the class of error for this to FileNotFoundError, we will be able to avoid that in the
|
|
||||||
# future.
|
|
||||||
LOCAL_FILES_ONLY_HF_ERROR = (
|
|
||||||
"Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable hf.co "
|
|
||||||
"look-ups and downloads online, set 'local_files_only' to False."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# In the future, this ugly contextmanager can be removed when huggingface_hub has a released version where we can
|
|
||||||
# activate/deactivate progress bars.
|
|
||||||
@contextmanager
|
|
||||||
def _patch_hf_hub_tqdm():
|
|
||||||
"""
|
|
||||||
A context manager to make huggingface hub use the tqdm version of Transformers (which is controlled by some utils)
|
|
||||||
in logging.
|
|
||||||
"""
|
|
||||||
old_tqdm = huggingface_hub.file_download.tqdm
|
|
||||||
huggingface_hub.file_download.tqdm = tqdm
|
|
||||||
yield
|
|
||||||
huggingface_hub.file_download.tqdm = old_tqdm
|
|
||||||
|
|
||||||
|
|
||||||
def cached_file(
|
def cached_file(
|
||||||
path_or_repo_id: Union[str, os.PathLike],
|
path_or_repo_id: Union[str, os.PathLike],
|
||||||
filename: str,
|
filename: str,
|
||||||
@@ -375,7 +357,6 @@ def cached_file(
|
|||||||
user_agent = http_user_agent(user_agent)
|
user_agent = http_user_agent(user_agent)
|
||||||
try:
|
try:
|
||||||
# Load from URL or cache if already cached
|
# Load from URL or cache if already cached
|
||||||
with _patch_hf_hub_tqdm():
|
|
||||||
resolved_file = hf_hub_download(
|
resolved_file = hf_hub_download(
|
||||||
path_or_repo_id,
|
path_or_repo_id,
|
||||||
filename,
|
filename,
|
||||||
@@ -403,6 +384,19 @@ def cached_file(
|
|||||||
"for this model name. Check the model page at "
|
"for this model name. Check the model page at "
|
||||||
f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
|
f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
|
||||||
)
|
)
|
||||||
|
except LocalEntryNotFoundError:
|
||||||
|
# We try to see if we have a cached version (not up to date):
|
||||||
|
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
|
||||||
|
if resolved_file is not None:
|
||||||
|
return resolved_file
|
||||||
|
if not _raise_exceptions_for_missing_entries or not _raise_exceptions_for_connection_errors:
|
||||||
|
return None
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
|
||||||
|
f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
|
||||||
|
f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
|
||||||
|
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||||
|
)
|
||||||
except EntryNotFoundError:
|
except EntryNotFoundError:
|
||||||
if not _raise_exceptions_for_missing_entries:
|
if not _raise_exceptions_for_missing_entries:
|
||||||
return None
|
return None
|
||||||
@@ -421,24 +415,6 @@ def cached_file(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}")
|
raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}")
|
||||||
except ValueError as err:
|
|
||||||
# HuggingFace Hub raises a ValueError for a missing file when local_files_only=True, so we need to catch it here.
|
|
||||||
# This could be caught above along in `EntryNotFoundError` if hf_hub sent a different error message here
|
|
||||||
if LOCAL_FILES_ONLY_HF_ERROR in err.args[0] and local_files_only and not _raise_exceptions_for_missing_entries:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Otherwise we try to see if we have a cached version (not up to date):
|
|
||||||
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
|
|
||||||
if resolved_file is not None:
|
|
||||||
return resolved_file
|
|
||||||
if not _raise_exceptions_for_connection_errors:
|
|
||||||
return None
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
|
|
||||||
f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
|
|
||||||
f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
|
|
||||||
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
|
||||||
)
|
|
||||||
|
|
||||||
return resolved_file
|
return resolved_file
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ from typing import Optional
|
|||||||
|
|
||||||
from tqdm import auto as tqdm_lib
|
from tqdm import auto as tqdm_lib
|
||||||
|
|
||||||
|
import huggingface_hub.utils as hf_hub_utils
|
||||||
|
|
||||||
|
|
||||||
_lock = threading.Lock()
|
_lock = threading.Lock()
|
||||||
_default_handler: Optional[logging.Handler] = None
|
_default_handler: Optional[logging.Handler] = None
|
||||||
@@ -336,9 +338,11 @@ def enable_progress_bar():
|
|||||||
"""Enable tqdm progress bar."""
|
"""Enable tqdm progress bar."""
|
||||||
global _tqdm_active
|
global _tqdm_active
|
||||||
_tqdm_active = True
|
_tqdm_active = True
|
||||||
|
hf_hub_utils.enable_progress_bars()
|
||||||
|
|
||||||
|
|
||||||
def disable_progress_bar():
|
def disable_progress_bar():
|
||||||
"""Disable tqdm progress bar."""
|
"""Disable tqdm progress bar."""
|
||||||
global _tqdm_active
|
global _tqdm_active
|
||||||
_tqdm_active = False
|
_tqdm_active = False
|
||||||
|
hf_hub_utils.disable_progress_bars()
|
||||||
|
|||||||
@@ -14,10 +14,10 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import transformers.models.bart.tokenization_bart
|
import transformers.models.bart.tokenization_bart
|
||||||
from transformers import AutoConfig, logging
|
from huggingface_hub.utils import are_progress_bars_disabled
|
||||||
|
from transformers import logging
|
||||||
from transformers.testing_utils import CaptureLogger, mockenv, mockenv_context
|
from transformers.testing_utils import CaptureLogger, mockenv, mockenv_context
|
||||||
from transformers.utils.logging import disable_progress_bar, enable_progress_bar
|
from transformers.utils.logging import disable_progress_bar, enable_progress_bar
|
||||||
|
|
||||||
@@ -126,14 +126,8 @@ class HfArgumentParserTest(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
def test_set_progress_bar_enabled():
|
def test_set_progress_bar_enabled():
|
||||||
TINY_MODEL = "hf-internal-testing/tiny-random-distilbert"
|
|
||||||
with patch("tqdm.auto.tqdm") as mock_tqdm:
|
|
||||||
disable_progress_bar()
|
disable_progress_bar()
|
||||||
_ = AutoConfig.from_pretrained(TINY_MODEL, force_download=True)
|
assert are_progress_bars_disabled()
|
||||||
mock_tqdm.assert_not_called()
|
|
||||||
|
|
||||||
mock_tqdm.reset_mock()
|
|
||||||
|
|
||||||
enable_progress_bar()
|
enable_progress_bar()
|
||||||
_ = AutoConfig.from_pretrained(TINY_MODEL, force_download=True)
|
assert not are_progress_bars_disabled()
|
||||||
mock_tqdm.assert_called()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user