From 38c3cd52fb6b39e2253d055ea583537efb29cd31 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Fri, 2 Sep 2022 10:30:06 -0400 Subject: [PATCH] Clean up utils.hub using the latest from hf_hub (#18857) * Clean up utils.hub using the latest from hf_hub * Adapt test * Address review comment * Fix test --- setup.py | 2 +- src/transformers/dependency_versions_table.py | 2 +- src/transformers/utils/hub.py | 88 +++++++------------ src/transformers/utils/logging.py | 4 + tests/utils/test_logging.py | 18 ++-- 5 files changed, 44 insertions(+), 70 deletions(-) diff --git a/setup.py b/setup.py index e974ff9a2b..8f101357e8 100644 --- a/setup.py +++ b/setup.py @@ -116,7 +116,7 @@ _deps = [ "fugashi>=1.0", "GitPython<3.1.19", "hf-doc-builder>=0.3.0", - "huggingface-hub>=0.8.1,<1.0", + "huggingface-hub>=0.9.0,<1.0", "importlib_metadata", "ipadic>=1.0.0,<2.0", "isort>=5.5.4", diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index c8f0f18793..58e4a2cd42 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -22,7 +22,7 @@ deps = { "fugashi": "fugashi>=1.0", "GitPython": "GitPython<3.1.19", "hf-doc-builder": "hf-doc-builder>=0.3.0", - "huggingface-hub": "huggingface-hub>=0.8.1,<1.0", + "huggingface-hub": "huggingface-hub>=0.9.0,<1.0", "importlib_metadata": "importlib_metadata", "ipadic": "ipadic>=1.0.0,<2.0", "isort": "isort>=5.5.4", diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 163ad64ffa..9b1e9a5b85 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -21,7 +21,6 @@ import shutil import sys import traceback import warnings -from contextlib import contextmanager from pathlib import Path from typing import Dict, List, Optional, Tuple, Union from uuid import uuid4 @@ -39,7 +38,12 @@ from huggingface_hub import ( ) from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT from huggingface_hub.file_download import REGEX_COMMIT_HASH -from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError +from huggingface_hub.utils import ( + EntryNotFoundError, + LocalEntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, +) from requests.exceptions import HTTPError from transformers.utils.logging import tqdm @@ -249,28 +253,6 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_h return cached_file if os.path.isfile(cached_file) else None -# If huggingface_hub changes the class of error for this to FileNotFoundError, we will be able to avoid that in the -# future. -LOCAL_FILES_ONLY_HF_ERROR = ( - "Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable hf.co " - "look-ups and downloads online, set 'local_files_only' to False." -) - - -# In the future, this ugly contextmanager can be removed when huggingface_hub as a released version where we can -# activate/deactivate progress bars. -@contextmanager -def _patch_hf_hub_tqdm(): - """ - A context manager to make huggingface hub use the tqdm version of Transformers (which is controlled by some utils) - in logging. - """ - old_tqdm = huggingface_hub.file_download.tqdm - huggingface_hub.file_download.tqdm = tqdm - yield - huggingface_hub.file_download.tqdm = old_tqdm - - def cached_file( path_or_repo_id: Union[str, os.PathLike], filename: str, @@ -375,20 +357,19 @@ def cached_file( user_agent = http_user_agent(user_agent) try: # Load from URL or cache if already cached - with _patch_hf_hub_tqdm(): - resolved_file = hf_hub_download( - path_or_repo_id, - filename, - subfolder=None if len(subfolder) == 0 else subfolder, - revision=revision, - cache_dir=cache_dir, - user_agent=user_agent, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - use_auth_token=use_auth_token, - local_files_only=local_files_only, - ) + resolved_file = hf_hub_download( + path_or_repo_id, + filename, + subfolder=None if len(subfolder) == 0 else subfolder, + revision=revision, + cache_dir=cache_dir, + user_agent=user_agent, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + use_auth_token=use_auth_token, + local_files_only=local_files_only, + ) except RepositoryNotFoundError: raise EnvironmentError( @@ -403,6 +384,19 @@ def cached_file( "for this model name. Check the model page at " f"'https://huggingface.co/{path_or_repo_id}' for available revisions." ) + except LocalEntryNotFoundError: + # We try to see if we have a cached version (not up to date): + resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision) + if resolved_file is not None: + return resolved_file + if not _raise_exceptions_for_missing_entries or not _raise_exceptions_for_connection_errors: + return None + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the" + f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named" + f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at" + " 'https://huggingface.co/docs/transformers/installation#offline-mode'." + ) except EntryNotFoundError: if not _raise_exceptions_for_missing_entries: return None @@ -421,24 +415,6 @@ def cached_file( return None raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}") - except ValueError as err: - # HuggingFace Hub returns a ValueError for a missing file when local_files_only=True we need to catch it here - # This could be caught above along in `EntryNotFoundError` if hf_hub sent a different error message here - if LOCAL_FILES_ONLY_HF_ERROR in err.args[0] and local_files_only and not _raise_exceptions_for_missing_entries: - return None - - # Otherwise we try to see if we have a cached version (not up to date): - resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision) - if resolved_file is not None: - return resolved_file - if not _raise_exceptions_for_connection_errors: - return None - raise EnvironmentError( - f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the" - f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named" - f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at" - " 'https://huggingface.co/docs/transformers/installation#offline-mode'." - ) return resolved_file diff --git a/src/transformers/utils/logging.py b/src/transformers/utils/logging.py index 91ecca7cfc..a98e2f30fd 100644 --- a/src/transformers/utils/logging.py +++ b/src/transformers/utils/logging.py @@ -30,6 +30,8 @@ from typing import Optional from tqdm import auto as tqdm_lib +import huggingface_hub.utils as hf_hub_utils + _lock = threading.Lock() _default_handler: Optional[logging.Handler] = None @@ -336,9 +338,11 @@ def enable_progress_bar(): """Enable tqdm progress bar.""" global _tqdm_active _tqdm_active = True + hf_hub_utils.enable_progress_bars() def disable_progress_bar(): """Disable tqdm progress bar.""" global _tqdm_active _tqdm_active = False + hf_hub_utils.disable_progress_bars() diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py index 81940d2d3b..81f3d9144a 100644 --- a/tests/utils/test_logging.py +++ b/tests/utils/test_logging.py @@ -14,10 +14,10 @@ import os import unittest -from unittest.mock import patch import transformers.models.bart.tokenization_bart -from transformers import AutoConfig, logging +from huggingface_hub.utils import are_progress_bars_disabled +from transformers import logging from transformers.testing_utils import CaptureLogger, mockenv, mockenv_context from transformers.utils.logging import disable_progress_bar, enable_progress_bar @@ -126,14 +126,8 @@ class HfArgumentParserTest(unittest.TestCase): def test_set_progress_bar_enabled(): - TINY_MODEL = "hf-internal-testing/tiny-random-distilbert" - with patch("tqdm.auto.tqdm") as mock_tqdm: - disable_progress_bar() - _ = AutoConfig.from_pretrained(TINY_MODEL, force_download=True) - mock_tqdm.assert_not_called() + disable_progress_bar() + assert are_progress_bars_disabled() - mock_tqdm.reset_mock() - - enable_progress_bar() - _ = AutoConfig.from_pretrained(TINY_MODEL, force_download=True) - mock_tqdm.assert_called() + enable_progress_bar() + assert not are_progress_bars_disabled()