diff --git a/docs/source/testing.rst b/docs/source/testing.rst index a3c8f847aa..f057e8bbcf 100644 --- a/docs/source/testing.rst +++ b/docs/source/testing.rst @@ -1080,6 +1080,8 @@ If you need to capture both streams at once, use the parent :obj:`CaptureStd` cl function_that_writes_to_stdout_and_stderr() print(cs.err, cs.out) +Also, to aid debugging test issues, by default these context managers automatically replay the captured streams on exit +from the context. Capturing logger stream diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 593bb469b6..c10c07f788 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -610,34 +610,54 @@ class CaptureStd: """ Context manager to capture: - - stdout, clean it up and make it available via obj.out - - stderr, and make it available via obj.err + - stdout: replay it, clean it up and make it available via ``obj.out`` + - stderr: replay it and make it available via ``obj.err`` init arguments: - - out - capture stdout: True/False, default True - - err - capture stdout: True/False, default True + - out - capture stdout: ``True``/``False``, default ``True`` + - err - capture stderr: ``True``/``False``, default ``True`` + - replay - whether to replay or not: ``True``/``False``, default ``True``. By default each + captured stream gets replayed back on context's exit, so that one can see what the test was doing. If this is + not the desired behavior and the captured data shouldn't be replayed, pass ``replay=False`` to disable this feature. 
Examples:: + # to capture stdout only with auto-replay with CaptureStdout() as cs: print("Secret message") - print(f"captured: {cs.out}") + assert "message" in cs.out + # to capture stderr only with auto-replay import sys with CaptureStderr() as cs: print("Warning: ", file=sys.stderr) - print(f"captured: {cs.err}") + assert "Warning" in cs.err - # to capture just one of the streams, but not the other + # to capture both streams with auto-replay + with CaptureStd() as cs: + print("Secret message") + print("Warning: ", file=sys.stderr) + assert "message" in cs.out + assert "Warning" in cs.err + + # to capture just one of the streams, and not the other, with auto-replay with CaptureStd(err=False) as cs: print("Secret message") - print(f"captured: {cs.out}") + assert "message" in cs.out # but best use the stream-specific subclasses + # to capture without auto-replay + with CaptureStd(replay=False) as cs: + print("Secret message") + assert "message" in cs.out + """ - def __init__(self, out=True, err=True): + def __init__(self, out=True, err=True, replay=True): + + self.replay = replay + if out: self.out_buf = StringIO() self.out = "error: CaptureStd context is unfinished yet, called too early" @@ -666,11 +686,17 @@ class CaptureStd: def __exit__(self, *exc): if self.out_buf: sys.stdout = self.out_old - self.out = apply_print_resets(self.out_buf.getvalue()) + captured = self.out_buf.getvalue() + if self.replay: + sys.stdout.write(captured) + self.out = apply_print_resets(captured) if self.err_buf: sys.stderr = self.err_old - self.err = self.err_buf.getvalue() + captured = self.err_buf.getvalue() + if self.replay: + sys.stderr.write(captured) + self.err = captured def __repr__(self): msg = "" @@ -690,15 +716,15 @@ class CaptureStd: class CaptureStdout(CaptureStd): """Same as CaptureStd but captures only stdout""" - def __init__(self): - super().__init__(err=False) + def __init__(self, replay=True): + super().__init__(err=False, replay=replay) class 
CaptureStderr(CaptureStd): """Same as CaptureStd but captures only stderr""" - def __init__(self): - super().__init__(out=False) + def __init__(self, replay=True): + super().__init__(out=False, replay=replay) class CaptureLogger: