testing utils: capturing std streams context manager (#6231)
* testing utils: capturing std streams context manager * style * missing import * add the origin of this code
This commit is contained in:
@@ -1,6 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
from distutils.util import strtobool
|
from distutils.util import strtobool
|
||||||
|
from io import StringIO
|
||||||
|
|
||||||
from .file_utils import _tf_available, _torch_available, _torch_tpu_available
|
from .file_utils import _tf_available, _torch_available, _torch_tpu_available
|
||||||
|
|
||||||
@@ -136,3 +139,119 @@ def require_torch_and_cuda(test_case):
|
|||||||
return unittest.skip("test requires CUDA")
|
return unittest.skip("test requires CUDA")
|
||||||
else:
|
else:
|
||||||
return test_case
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Helper functions for dealing with testing text outputs
|
||||||
|
# The original code came from:
|
||||||
|
# https://github.com/fastai/fastai/blob/master/tests/utils/text.py
|
||||||
|
|
||||||
|
# When any function contains print() calls that get overwritten, like progress bars,
|
||||||
|
# a special care needs to be applied, since under pytest -s captured output (capsys
|
||||||
|
# or contextlib.redirect_stdout) contains any temporary printed strings, followed by
|
||||||
|
# \r's. This helper function ensures that the buffer will contain the same output
|
||||||
|
# with and without -s in pytest, by turning:
|
||||||
|
# foo bar\r tar mar\r final message
|
||||||
|
# into:
|
||||||
|
# final message
|
||||||
|
# it can handle a single string or a multiline buffer
|
||||||
|
def apply_print_resets(buf):
|
||||||
|
return re.sub(r"^.*\r", "", buf, 0, re.M)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_screenout(out, what):
|
||||||
|
out_pr = apply_print_resets(out).lower()
|
||||||
|
match_str = out_pr.find(what.lower())
|
||||||
|
assert match_str != -1, f"expecting to find {what} in output: f{out_pr}"
|
||||||
|
|
||||||
|
|
||||||
|
class CaptureStd:
|
||||||
|
""" Context manager to capture:
|
||||||
|
stdout, clean it up and make it available via obj.out
|
||||||
|
stderr, and make it available via obj.err
|
||||||
|
|
||||||
|
init arguments:
|
||||||
|
- out - capture stdout: True/False, default True
|
||||||
|
- err - capture stdout: True/False, default True
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
with CaptureStdout() as cs:
|
||||||
|
print("Secret message")
|
||||||
|
print(f"captured: {cs.out}")
|
||||||
|
|
||||||
|
import sys
|
||||||
|
with CaptureStderr() as cs:
|
||||||
|
print("Warning: ", file=sys.stderr)
|
||||||
|
print(f"captured: {cs.err}")
|
||||||
|
|
||||||
|
# to capture just one of the streams, but not the other
|
||||||
|
with CaptureStd(err=False) as cs:
|
||||||
|
print("Secret message")
|
||||||
|
print(f"captured: {cs.out}")
|
||||||
|
# but best use the stream-specific subclasses
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, out=True, err=True):
|
||||||
|
if out:
|
||||||
|
self.out_buf = StringIO()
|
||||||
|
self.out = "error: CaptureStd context is unfinished yet, called too early"
|
||||||
|
else:
|
||||||
|
self.out_buf = None
|
||||||
|
self.out = "not capturing stdout"
|
||||||
|
|
||||||
|
if err:
|
||||||
|
self.err_buf = StringIO()
|
||||||
|
self.err = "error: CaptureStd context is unfinished yet, called too early"
|
||||||
|
else:
|
||||||
|
self.err_buf = None
|
||||||
|
self.err = "not capturing stderr"
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
if self.out_buf:
|
||||||
|
self.out_old = sys.stdout
|
||||||
|
sys.stdout = self.out_buf
|
||||||
|
|
||||||
|
if self.err_buf:
|
||||||
|
self.err_old = sys.stderr
|
||||||
|
sys.stderr = self.err_buf
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *exc):
|
||||||
|
if self.out_buf:
|
||||||
|
sys.stdout = self.out_old
|
||||||
|
self.out = apply_print_resets(self.out_buf.getvalue())
|
||||||
|
|
||||||
|
if self.err_buf:
|
||||||
|
sys.stderr = self.err_old
|
||||||
|
self.err = self.err_buf.getvalue()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
msg = ""
|
||||||
|
if self.out_buf:
|
||||||
|
msg += f"stdout: {self.out}\n"
|
||||||
|
if self.err_buf:
|
||||||
|
msg += f"stderr: {self.err}\n"
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
# in tests it's the best to capture only the stream that's wanted, otherwise
|
||||||
|
# it's easy to miss things, so unless you need to capture both streams, use the
|
||||||
|
# subclasses below (less typing). Or alternatively, configure `CaptureStd` to
|
||||||
|
# disable the stream you don't need to test.
|
||||||
|
|
||||||
|
|
||||||
|
class CaptureStdout(CaptureStd):
|
||||||
|
""" Same as CaptureStd but captures only stdout """
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(err=False)
|
||||||
|
|
||||||
|
|
||||||
|
class CaptureStderr(CaptureStd):
|
||||||
|
""" Same as CaptureStd but captures only stderr """
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(out=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user