From fe78fe98ca56c093d0503590cbad0b39ce3326a0 Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Wed, 19 Jan 2022 07:52:35 +0900 Subject: [PATCH] Enable tqdm toggling (#15167) * feature: enable tqdm toggle * test: add tqdm unit test * style: run linter * Update tests/test_tqdm_utils.py Co-authored-by: Stas Bekman * refactor: use tiny model, run linter * docs: add tqdm to logging * docs: add tqdm reference to `http_get` * style: run linter * Update docs/source/main_classes/logging.mdx Co-authored-by: Stas Bekman * fix: use `AutoConfig` for framework agnostic testing * chore: mv tqdm test to `test_logging.py` * feature: implement enable/disable functions * docs: mv docstring to comment * chore: mv tqdm functions to `logging.py` * docs: update docs to reference `enable/disable` funcs * test: update test to use `enable/disable` func * chore: update function reference in comment Co-authored-by: Stas Bekman --- docs/source/main_classes/logging.mdx | 6 +++ src/transformers/file_utils.py | 5 ++- src/transformers/utils/logging.py | 66 ++++++++++++++++++++++++++++ tests/test_logging.py | 18 +++++++- 4 files changed, 92 insertions(+), 3 deletions(-) diff --git a/docs/source/main_classes/logging.mdx b/docs/source/main_classes/logging.mdx index b707ca8698..ac0717443d 100644 --- a/docs/source/main_classes/logging.mdx +++ b/docs/source/main_classes/logging.mdx @@ -54,6 +54,8 @@ verbose to the most verbose), those levels (with their corresponding int values - `transformers.logging.INFO` (int value, 20): reports error, warnings and basic information. - `transformers.logging.DEBUG` (int value, 10): report all information. +By default, `tqdm` progress bars will be displayed during model download. [`logging.disable_progress_bar`] and [`logging.enable_progress_bar`] can be used to suppress or unsuppress this behavior. + ## Base setters [[autodoc]] logging.set_verbosity_error @@ -79,3 +81,7 @@ verbose to the most verbose), those levels (with their corresponding int values [[autodoc]] logging.enable_explicit_format [[autodoc]] logging.reset_format + +[[autodoc]] logging.enable_progress_bar + +[[autodoc]] logging.disable_progress_bar diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index f8599796cb..57eba81c27 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -45,12 +45,12 @@ from zipfile import ZipFile, is_zipfile import numpy as np from packaging import version -from tqdm.auto import tqdm import requests from filelock import FileLock from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami from requests.exceptions import HTTPError +from transformers.utils.logging import tqdm from transformers.utils.versions import importlib_metadata from . import __version__ @@ -1911,6 +1911,8 @@ def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers r.raise_for_status() content_length = r.headers.get("Content-Length") total = resume_size + int(content_length) if content_length is not None else None + # `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()` + # and can be set using `utils.logging.enable/disable_progress_bar()` progress = tqdm( unit="B", unit_scale=True, @@ -1918,7 +1920,6 @@ def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers total=total, initial=resume_size, desc="Downloading", - disable=bool(logging.get_verbosity() == logging.NOTSET), ) for chunk in r.iter_content(chunk_size=1024): if chunk: # filter out keep-alive new chunks diff --git a/src/transformers/utils/logging.py b/src/transformers/utils/logging.py index 659022a009..7cf2fb7a72 100644 --- a/src/transformers/utils/logging.py +++ b/src/transformers/utils/logging.py @@ -28,6 +28,8 @@ from logging import WARN # NOQA from logging import WARNING # NOQA from typing import Optional +from tqdm import auto as tqdm_lib + _lock = threading.Lock() _default_handler: Optional[logging.Handler] = None @@ -42,6 +44,8 @@ log_levels = { _default_log_level = logging.WARNING +_tqdm_active = True + def _get_default_logging_level(): """ @@ -276,3 +280,65 @@ def warning_advice(self, *args, **kwargs): logging.Logger.warning_advice = warning_advice + + +class EmptyTqdm: + """Dummy tqdm which doesn't do anything.""" + + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + self._iterator = args[0] if args else None + + def __iter__(self): + return iter(self._iterator) + + def __getattr__(self, _): + """Return empty function.""" + + def empty_fn(*args, **kwargs): # pylint: disable=unused-argument + return + + return empty_fn + + def __enter__(self): + return self + + def __exit__(self, type_, value, traceback): + return + + +class _tqdm_cls: + def __call__(self, *args, **kwargs): + if _tqdm_active: + return tqdm_lib.tqdm(*args, **kwargs) + else: + return EmptyTqdm(*args, **kwargs) + + def set_lock(self, *args, **kwargs): + self._lock = None + if _tqdm_active: + return tqdm_lib.tqdm.set_lock(*args, **kwargs) + + def get_lock(self): + if _tqdm_active: + return tqdm_lib.tqdm.get_lock() + + +tqdm = _tqdm_cls() + + +def is_progress_bar_enabled() -> bool: + """Return a boolean indicating whether tqdm progress bars are enabled.""" + global _tqdm_active + return bool(_tqdm_active) + + +def enable_progress_bar(): + """Enable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = True + + +def disable_progress_bar(): + """Enable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = False diff --git a/tests/test_logging.py b/tests/test_logging.py index 914a46a3c8..81940d2d3b 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -14,10 +14,12 @@ import os import unittest +from unittest.mock import patch import transformers.models.bart.tokenization_bart -from transformers import logging +from transformers import AutoConfig, logging from transformers.testing_utils import CaptureLogger, mockenv, mockenv_context +from transformers.utils.logging import disable_progress_bar, enable_progress_bar class HfArgumentParserTest(unittest.TestCase): @@ -121,3 +123,17 @@ class HfArgumentParserTest(unittest.TestCase): with CaptureLogger(logger) as cl: logger.warning_advice(msg) self.assertEqual(cl.out, msg + "\n") + + +def test_set_progress_bar_enabled(): + TINY_MODEL = "hf-internal-testing/tiny-random-distilbert" + with patch("tqdm.auto.tqdm") as mock_tqdm: + disable_progress_bar() + _ = AutoConfig.from_pretrained(TINY_MODEL, force_download=True) + mock_tqdm.assert_not_called() + + mock_tqdm.reset_mock() + + enable_progress_bar() + _ = AutoConfig.from_pretrained(TINY_MODEL, force_download=True) + mock_tqdm.assert_called()