From 83984a61c657023a69e78951b338e378a0f866c2 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Tue, 11 Aug 2020 00:56:47 -0700 Subject: [PATCH] testing utils: capturing std streams context manager (#6231) * testing utils: capturing std streams context manager * style * missing import * add the origin of this code --- src/transformers/testing_utils.py | 119 ++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 1fb9e83868..71c205fbb0 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1,6 +1,9 @@ import os +import re +import sys import unittest from distutils.util import strtobool +from io import StringIO from .file_utils import _tf_available, _torch_available, _torch_tpu_available @@ -136,3 +139,119 @@ def require_torch_and_cuda(test_case): return unittest.skip("test requires CUDA") else: return test_case + + +# +# Helper functions for dealing with testing text outputs +# The original code came from: +# https://github.com/fastai/fastai/blob/master/tests/utils/text.py + +# When any function contains print() calls that get overwritten, like progress bars, +# a special care needs to be applied, since under pytest -s captured output (capsys +# or contextlib.redirect_stdout) contains any temporary printed strings, followed by +# \r's. This helper function ensures that the buffer will contain the same output +# with and without -s in pytest, by turning: +# foo bar\r tar mar\r final message +# into: +# final message +# it can handle a single string or a multiline buffer +def apply_print_resets(buf): + return re.sub(r"^.*\r", "", buf, 0, re.M) + + +def assert_screenout(out, what): + out_pr = apply_print_resets(out).lower() + match_str = out_pr.find(what.lower()) + assert match_str != -1, f"expecting to find {what} in output: f{out_pr}" + + +class CaptureStd: + """ Context manager to capture: + stdout, clean it up and make it available via obj.out + stderr, and make it available via obj.err + + init arguments: + - out - capture stdout: True/False, default True + - err - capture stdout: True/False, default True + + Examples: + + with CaptureStdout() as cs: + print("Secret message") + print(f"captured: {cs.out}") + + import sys + with CaptureStderr() as cs: + print("Warning: ", file=sys.stderr) + print(f"captured: {cs.err}") + + # to capture just one of the streams, but not the other + with CaptureStd(err=False) as cs: + print("Secret message") + print(f"captured: {cs.out}") + # but best use the stream-specific subclasses + + """ + + def __init__(self, out=True, err=True): + if out: + self.out_buf = StringIO() + self.out = "error: CaptureStd context is unfinished yet, called too early" + else: + self.out_buf = None + self.out = "not capturing stdout" + + if err: + self.err_buf = StringIO() + self.err = "error: CaptureStd context is unfinished yet, called too early" + else: + self.err_buf = None + self.err = "not capturing stderr" + + def __enter__(self): + if self.out_buf: + self.out_old = sys.stdout + sys.stdout = self.out_buf + + if self.err_buf: + self.err_old = sys.stderr + sys.stderr = self.err_buf + + return self + + def __exit__(self, *exc): + if self.out_buf: + sys.stdout = self.out_old + self.out = apply_print_resets(self.out_buf.getvalue()) + + if self.err_buf: + sys.stderr = self.err_old + self.err = self.err_buf.getvalue() + + def __repr__(self): + msg = "" + if self.out_buf: + msg += f"stdout: {self.out}\n" + if self.err_buf: + msg += f"stderr: {self.err}\n" + return msg + + +# in tests it's the best to capture only the stream that's wanted, otherwise +# it's easy to miss things, so unless you need to capture both streams, use the +# subclasses below (less typing). Or alternatively, configure `CaptureStd` to +# disable the stream you don't need to test. + + +class CaptureStdout(CaptureStd): + """ Same as CaptureStd but captures only stdout """ + + def __init__(self): + super().__init__(err=False) + + +class CaptureStderr(CaptureStd): + """ Same as CaptureStd but captures only stderr """ + + def __init__(self): + super().__init__(out=False)