adding TRANSFORMERS_VERBOSITY env var (#6961)

* introduce TRANSFORMERS_VERBOSITY env var + test + test helpers

* cleanup

* remove helper function
This commit is contained in:
Stas Bekman
2020-09-09 01:08:01 -07:00
committed by GitHub
parent f0fc0aea6b
commit d0963486c1
4 changed files with 163 additions and 4 deletions

View File

@@ -1,4 +1,5 @@
import inspect
import logging
import os
import re
import shutil
@@ -270,6 +271,46 @@ class CaptureStderr(CaptureStd):
super().__init__(out=False)
class CaptureLogger:
"""Context manager to capture `logging` streams
Args:
- logger: 'logging` logger object
Results:
The captured output is available via `self.out`
Example:
from transformers import logging
from transformers.testing_utils import CaptureLogger
msg = "Testing 1, 2, 3"
logging.set_verbosity_info()
logger = logging.get_logger("transformers.tokenization_bart")
with CaptureLogger(logger) as cl:
logger.info(msg)
assert cl.out, msg+"\n"
"""
def __init__(self, logger):
self.logger = logger
self.io = StringIO()
self.sh = logging.StreamHandler(self.io)
self.out = ""
def __enter__(self):
self.logger.addHandler(self.sh)
return self
def __exit__(self, *exc):
self.logger.removeHandler(self.sh)
self.out = self.io.getvalue()
def __repr__(self):
return f"captured: {self.out}\n"
class TestCasePlus(unittest.TestCase):
"""This class extends `unittest.TestCase` with additional features.
@@ -357,3 +398,14 @@ class TestCasePlus(unittest.TestCase):
for path in self.teardown_tmp_dirs:
shutil.rmtree(path, ignore_errors=True)
self.teardown_tmp_dirs = []
def mockenv(**kwargs):
"""this is a convenience wrapper, that allows this:
@mockenv(USE_CUDA=True, USE_TF=False)
def test_something():
use_cuda = os.getenv("USE_CUDA", False)
use_tf = os.getenv("USE_TF", False)
"""
return unittest.mock.patch.dict(os.environ, kwargs)

View File

@@ -15,6 +15,7 @@
""" Logging utilities. """
import logging
import os
import threading
from logging import CRITICAL # NOQA
from logging import DEBUG # NOQA
@@ -30,6 +31,33 @@ from typing import Optional
_lock = threading.Lock()
_default_handler: Optional[logging.Handler] = None
log_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"error": logging.ERROR,
"critical": logging.CRITICAL,
}
_default_log_level = logging.WARNING
def _get_default_logging_level():
"""
If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level.
If it is not - fall back to ``_default_log_level``
"""
env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None)
if env_level_str:
if env_level_str in log_levels:
return log_levels[env_level_str]
else:
logging.getLogger().warning(
f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, "
f"has to be one of: { ', '.join(log_levels.keys()) }"
)
return _default_log_level
def _get_library_name() -> str:
@@ -54,7 +82,7 @@ def _configure_library_root_logger() -> None:
# Apply our default configuration to the library root logger.
library_root_logger = _get_library_root_logger()
library_root_logger.addHandler(_default_handler)
library_root_logger.setLevel(logging.WARN)
library_root_logger.setLevel(_get_default_logging_level())
library_root_logger.propagate = False