* skip decorators: docs, tests, bugs * another important note * style * bloody style * add @pytest.mark.parametrize * add note * no idea what it wants :(
450 lines
13 KiB
Python
450 lines
13 KiB
Python
import inspect
|
|
import logging
|
|
import os
|
|
import re
|
|
import shutil
|
|
import sys
|
|
import tempfile
|
|
import unittest
|
|
from distutils.util import strtobool
|
|
from io import StringIO
|
|
from pathlib import Path
|
|
|
|
from .file_utils import _datasets_available, _faiss_available, _tf_available, _torch_available, _torch_tpu_available
|
|
|
|
|
|
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown"
# Correctly spelled alias of the identifier above; the misspelled name is kept
# so existing callers keep working.
DUMMY_UNKNOWN_IDENTIFIER = DUMMY_UNKWOWN_IDENTIFIER
DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
# Used to test Auto{Config, Model, Tokenizer} model_type detection.
|
|
|
|
|
|
def _strtobool(value):
    """Convert a string representation of truth to ``True`` or ``False``.

    Accepts the same spellings as the deprecated ``distutils.util.strtobool``
    (case-insensitive), but returns a real ``bool`` instead of ``1``/``0``.

    Raises:
        ValueError: if ``value`` is not one of the recognised spellings.
    """
    lowered = value.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {value!r}")


def parse_flag_from_env(key, default=False):
    """Read a boolean flag from the environment variable ``key``.

    Args:
        key: name of the environment variable to read.
        default: value returned (unchanged) when ``key`` is not set.

    Returns:
        ``True``/``False`` parsed from the variable when it is set, otherwise ``default``.

    Raises:
        ValueError: if the variable is set to a value that is not a recognised
            yes/no spelling (``y``/``yes``/``t``/``true``/``on``/``1`` or their negatives).
    """
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, default to `default`.
        return default

    # KEY is set, convert it to True or False.
    try:
        return _strtobool(value)
    except ValueError as err:
        # More values are supported, but let's keep the message simple.
        raise ValueError("If set, {} must be yes or no.".format(key)) from err
|
|
|
|
|
|
def parse_int_from_env(key, default=None):
    """Read an integer from the environment variable ``key``.

    Args:
        key: name of the environment variable to read.
        default: value returned (unchanged) when ``key`` is not set.

    Returns:
        ``int(os.environ[key])`` when the variable is set, otherwise ``default``.

    Raises:
        ValueError: if the variable is set but cannot be parsed by ``int()``.
    """
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, default to `default`.
        return default

    try:
        return int(value)
    except ValueError as err:
        raise ValueError("If set, {} must be an int.".format(key)) from err
|
|
|
|
|
|
# Opt-in switches, read once at import time from the environment.
_run_slow_tests = parse_flag_from_env(key="RUN_SLOW", default=False)
_run_custom_tokenizers = parse_flag_from_env(key="RUN_CUSTOM_TOKENIZERS", default=False)

# Integer limit read from TF_GPU_MEMORY_LIMIT; `None` when the variable is unset.
# NOTE(review): presumably consumed by TensorFlow GPU setup elsewhere — confirm.
_tf_gpu_memory_limit = parse_int_from_env(key="TF_GPU_MEMORY_LIMIT", default=None)
|
|
|
|
|
|
def slow(test_case):
    """Skip ``test_case`` unless slow tests were explicitly requested.

    Slow tests are skipped by default. Set the RUN_SLOW environment variable to a
    truthy value (e.g. ``yes`` or ``1``) to run them.
    """
    return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
|
|
|
|
|
|
def custom_tokenizers(test_case):
    """Skip ``test_case`` unless custom-tokenizer tests were explicitly requested.

    Custom tokenizers require additional dependencies, so these tests are skipped by
    default. Set the RUN_CUSTOM_TOKENIZERS environment variable to a truthy value
    (e.g. ``yes`` or ``1``) to run them.
    """
    return unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")(test_case)
|
|
|
|
|
|
def require_torch(test_case):
    """Skip ``test_case`` when PyTorch isn't installed.

    Availability is taken from the module-level ``_torch_available`` flag.
    """
    return unittest.skipUnless(_torch_available, "test requires PyTorch")(test_case)
|
|
|
|
|
|
def require_tf(test_case):
    """Skip ``test_case`` when TensorFlow isn't installed.

    Availability is taken from the module-level ``_tf_available`` flag.
    """
    return unittest.skipUnless(_tf_available, "test requires TensorFlow")(test_case)
|
|
|
|
|
|
def require_multigpu(test_case):
    """Skip ``test_case`` unless PyTorch sees at least two CUDA devices.

    Tests are skipped when PyTorch isn't installed, or when
    ``torch.cuda.device_count()`` is below 2.

    To run *only* the multigpu tests, assuming all test names contain multigpu:
    $ pytest -sv ./tests -k "multigpu"
    """
    if not _torch_available:
        return unittest.skip("test requires PyTorch")(test_case)

    # Imported lazily so this module stays importable without PyTorch.
    import torch

    has_multiple_gpus = torch.cuda.device_count() >= 2
    return unittest.skipUnless(has_multiple_gpus, "test requires multiple GPUs")(test_case)
|
|
|
|
|
|
def require_non_multigpu(test_case):
    """Skip ``test_case`` unless PyTorch sees at most one CUDA device.

    Tests are skipped when PyTorch isn't installed, or when
    ``torch.cuda.device_count()`` is greater than 1.
    """
    if not _torch_available:
        return unittest.skip("test requires PyTorch")(test_case)

    # Imported lazily so this module stays importable without PyTorch.
    import torch

    at_most_one_gpu = torch.cuda.device_count() <= 1
    return unittest.skipUnless(at_most_one_gpu, "test requires 0 or 1 GPU")(test_case)
|
|
|
|
|
|
def require_torch_tpu(test_case):
    """Decorator marking a test that requires a TPU (in PyTorch).

    The test is skipped when the module-level ``_torch_tpu_available`` flag is false.
    """
    if not _torch_tpu_available:
        # `unittest.skip(reason)` only builds a decorator; it has to be applied to
        # `test_case`. Returning it bare would replace the test with that decorator,
        # so the test would never be reported as skipped.
        return unittest.skip("test requires PyTorch TPU")(test_case)
    return test_case
|
|
|
|
|
|
# Device for PyTorch tests: `None` without PyTorch; otherwise "cuda" when the
# USE_CUDA environment flag is truthy and "cpu" when it is not.
torch_device = None
if _torch_available:
    # Set the USE_CUDA environment variable to select a GPU.
    torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"
|
|
|
|
|
|
def require_torch_and_cuda(test_case):
    """Decorator marking a test that requires CUDA and PyTorch.

    The test runs only when the module-level ``torch_device`` equals ``"cuda"``.
    """
    on_cuda = torch_device == "cuda"
    return unittest.skipUnless(on_cuda, "test requires CUDA")(test_case)
|
|
|
|
|
|
def require_datasets(test_case):
    """Skip ``test_case`` when the `datasets` library isn't available.

    Availability is taken from the module-level ``_datasets_available`` flag.
    """
    return unittest.skipUnless(_datasets_available, "test requires `datasets`")(test_case)
|
|
|
|
|
|
def require_faiss(test_case):
    """Skip ``test_case`` when the `faiss` library isn't available.

    Availability is taken from the module-level ``_faiss_available`` flag.
    """
    return unittest.skipUnless(_faiss_available, "test requires `faiss`")(test_case)
|
|
|
|
|
|
def get_tests_dir(append_path=None):
    """Return the absolute path of the directory containing the caller's source file.

    This lets tests locate files next to themselves regardless of the working
    directory they are invoked from.

    Args:
        append_path (:obj:`str`, `optional`):
            if given, it is joined onto the directory and the combined path is returned.

    Returns:
        :obj:`str`: the caller's directory, with ``append_path`` joined on when provided.
    """
    # Frame [1] is whoever called this function. `context=0` avoids reading source
    # lines for every frame on the stack, which `inspect.stack()` does by default.
    caller__file__ = inspect.stack(0)[1].filename
    tests_dir = os.path.abspath(os.path.dirname(caller__file__))
    if append_path:
        return os.path.join(tests_dir, append_path)
    return tests_dir
|
|
|
|
|
|
#
# Helper functions for dealing with testing text outputs
# The original code came from:
# https://github.com/fastai/fastai/blob/master/tests/utils/text.py

# When any function contains print() calls that get overwritten, like progress bars,
# special care needs to be applied: under `pytest -s`, captured output (capsys
# or contextlib.redirect_stdout) contains every temporarily printed string, each
# followed by \r. This helper ensures that the buffer contains the same output
# with and without -s in pytest, by turning:
# foo bar\r tar mar\r final message
# into:
# final message
# It can handle a single string or a multiline buffer.
def apply_print_resets(buf):
    """Drop, on every line of ``buf``, everything up to and including the last carriage return.

    Args:
        buf (:obj:`str`): captured text, single- or multi-line.

    Returns:
        :obj:`str`: the text as it would finally appear on a terminal.
    """
    # Pass `flags` by keyword: positional `count`/`flags` arguments to `re.sub`
    # are deprecated since Python 3.13.
    return re.sub(r"^.*\r", "", buf, flags=re.M)
|
|
|
|
|
|
def assert_screenout(out, what):
    """Assert that ``what`` appears, case-insensitively, in the captured text ``out``.

    ``out`` is first cleaned with ``apply_print_resets``, so that text later
    overwritten via carriage returns (e.g. progress bars) does not affect the match.

    Raises:
        AssertionError: if ``what`` is not found; the message shows the cleaned output.
    """
    out_pr = apply_print_resets(out).lower()
    assert what.lower() in out_pr, f"expecting to find {what} in output: {out_pr}"
|
|
|
|
|
|
class CaptureStd:
    """Context manager to capture:

        stdout, clean it up (via ``apply_print_resets``) and make it available via obj.out
        stderr, and make it available as-is via obj.err

    init arguments:
        - out - capture stdout: True/False, default True
        - err - capture stderr: True/False, default True

    Examples:

        with CaptureStdout() as cs:
            print("Secret message")
        print(f"captured: {cs.out}")

        import sys
        with CaptureStderr() as cs:
            print("Warning: ", file=sys.stderr)
        print(f"captured: {cs.err}")

        # to capture just one of the streams, but not the other
        with CaptureStd(err=False) as cs:
            print("Secret message")
        print(f"captured: {cs.out}")
        # but best use the stream-specific subclasses

    """

    def __init__(self, out=True, err=True):
        # Until the context exits, `out`/`err` hold placeholder strings so that
        # reading them too early (or for a stream that isn't captured) is obvious.
        if out:
            self.out_buf = StringIO()
            self.out = "error: CaptureStd context is unfinished yet, called too early"
        else:
            self.out_buf = None
            self.out = "not capturing stdout"

        if err:
            self.err_buf = StringIO()
            self.err = "error: CaptureStd context is unfinished yet, called too early"
        else:
            self.err_buf = None
            self.err = "not capturing stderr"

    def __enter__(self):
        # Swap the process-wide streams for in-memory buffers, remembering the originals.
        if self.out_buf is not None:
            self.out_old = sys.stdout
            sys.stdout = self.out_buf

        if self.err_buf is not None:
            self.err_old = sys.stderr
            sys.stderr = self.err_buf

        return self

    def __exit__(self, *exc):
        # Streams are restored even if the body raised; returning None lets any
        # exception propagate.
        if self.out_buf is not None:
            sys.stdout = self.out_old
            self.out = apply_print_resets(self.out_buf.getvalue())

        if self.err_buf is not None:
            sys.stderr = self.err_old
            self.err = self.err_buf.getvalue()

    def __repr__(self):
        msg = ""
        if self.out_buf is not None:
            msg += f"stdout: {self.out}\n"
        if self.err_buf is not None:
            msg += f"stderr: {self.err}\n"
        return msg
|
|
|
|
|
|
# In tests it's best to capture only the stream that's wanted, otherwise
# it's easy to miss things. So unless you need to capture both streams, use the
# subclasses below (less typing). Alternatively, configure `CaptureStd` to
# disable the stream you don't need to test.
|
|
|
|
|
|
class CaptureStdout(CaptureStd):
    """Variant of ``CaptureStd`` that captures stdout only; stderr is left untouched."""

    def __init__(self):
        # stdout capture stays on (the base-class default); stderr capture is switched off.
        super().__init__(out=True, err=False)
|
|
|
|
|
|
class CaptureStderr(CaptureStd):
    """Variant of ``CaptureStd`` that captures stderr only; stdout is left untouched."""

    def __init__(self):
        # stdout capture is switched off; stderr capture stays on (the base-class default).
        super().__init__(out=False, err=True)
|
|
|
|
|
|
class CaptureLogger:
    """Context manager to capture `logging` streams

    A ``logging.StreamHandler`` writing into an in-memory buffer is attached to
    ``logger`` on entry and removed on exit. Only records that pass the logger's
    level reach the handler.

    Args:
        - logger: `logging` logger object

    Results:
        The captured output is available via `self.out`

    Example:

    >>> from transformers import logging
    >>> from transformers.testing_utils import CaptureLogger

    >>> msg = "Testing 1, 2, 3"
    >>> logging.set_verbosity_info()
    >>> logger = logging.get_logger("transformers.tokenization_bart")
    >>> with CaptureLogger(logger) as cl:
    ...     logger.info(msg)
    >>> assert cl.out == msg + "\\n"
    """

    def __init__(self, logger):
        self.logger = logger
        self.io = StringIO()
        self.sh = logging.StreamHandler(self.io)
        self.out = ""

    def __enter__(self):
        self.logger.addHandler(self.sh)
        return self

    def __exit__(self, *exc):
        # The handler is detached even if the body raised; returning None lets any
        # exception propagate.
        self.logger.removeHandler(self.sh)
        self.out = self.io.getvalue()

    def __repr__(self):
        return f"captured: {self.out}\n"
|
|
|
|
|
|
class TestCasePlus(unittest.TestCase):
    """This class extends `unittest.TestCase` with additional features.

    Feature 1: Flexible auto-removable temp dirs which are guaranteed to get
    removed at the end of test.

    In all the following scenarios the temp dir will be auto-removed at the end
    of test, unless `after=False`.

    # 1. create a unique temp dir, `tmp_dir` will contain the path to the created temp dir
    def test_whatever(self):
        tmp_dir = self.get_auto_remove_tmp_dir()

    # 2. create a temp dir of my choice and delete it at the end - useful for debug when you want to
    # monitor a specific directory
    def test_whatever(self):
        tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test")

    # 3. create a temp dir of my choice and do not delete it at the end - useful for when you want
    # to look at the temp results
    def test_whatever(self):
        tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test", after=False)

    # 4. create a temp dir of my choice and ensure to delete it right away - useful for when you
    # disabled deletion in the previous test run and want to make sure that the tmp dir is empty
    # before the new test is run
    def test_whatever(self):
        tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test", before=True)

    Note 1: In order to run the equivalent of `rm -r` safely, an explicit `tmp_dir`
    must start with `./` and must point to a strict subdirectory of the current
    working directory (normally the project repository checkout), so that by
    mistake no `/tmp`, the checkout itself, or a similar important part of the
    filesystem will get nuked.

    Note 2: Each test can register multiple temp dirs and they all will get
    auto-removed, unless requested otherwise.

    """

    def setUp(self):
        super().setUp()
        # temp dirs registered for deletion in `tearDown`
        self.teardown_tmp_dirs = []

    def get_auto_remove_tmp_dir(self, tmp_dir=None, after=True, before=False):
        """
        Args:
            tmp_dir (:obj:`string`, `optional`):
                use this path, if None a unique path will be assigned. Must start with
                `./` and stay inside the current working directory.
            before (:obj:`bool`, `optional`, defaults to :obj:`False`):
                if `True` and tmp dir already exists make sure to empty it right away
            after (:obj:`bool`, `optional`, defaults to :obj:`True`):
                delete the tmp dir at the end of the test

        Returns:
            tmp_dir(:obj:`string`):
                either the same value as passed via `tmp_dir` or the path to the auto-created tmp dir

        Raises:
            ValueError: if `tmp_dir` does not start with `./`, or if it does not resolve
                to a strict subdirectory of the current working directory.
        """
        if tmp_dir is not None:
            # using provided path

            # to avoid nuking parts of the filesystem, only relative paths are allowed
            if not tmp_dir.startswith("./"):
                raise ValueError(
                    f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`"
                )

            # "./" itself or "./../x" pass the prefix check but point at, or outside,
            # the current directory. Check containment lexically (abspath normalizes
            # ".."), without resolving symlinks, so symlinked subdirs keep working.
            cwd = Path(os.getcwd())
            if cwd not in Path(os.path.abspath(tmp_dir)).parents:
                raise ValueError(
                    f"`tmp_dir` must be a subdirectory of the current directory `{cwd}`, but received `{tmp_dir}`"
                )

            path = Path(tmp_dir).resolve()

            # ensure the dir is empty to start with
            if before and path.exists():
                shutil.rmtree(path, ignore_errors=True)

            path.mkdir(parents=True, exist_ok=True)

        else:
            # using unique tmp dir (always empty, regardless of `before`)
            tmp_dir = tempfile.mkdtemp()

        if after:
            # register for deletion
            self.teardown_tmp_dirs.append(tmp_dir)

        return tmp_dir

    def tearDown(self):
        # remove registered temp dirs (best-effort: errors are ignored)
        for path in self.teardown_tmp_dirs:
            shutil.rmtree(path, ignore_errors=True)
        self.teardown_tmp_dirs = []
        super().tearDown()
|
|
|
|
|
|
def mockenv(**kwargs):
    """Temporarily set environment variables, as a decorator or a context manager.

    This is a convenience wrapper around ``unittest.mock.patch.dict(os.environ, kwargs)``
    that allows this:

    @mockenv(USE_CUDA="1", USE_TF="0")
    def test_something():
        use_cuda = os.getenv("USE_CUDA", False)
        use_tf = os.getenv("USE_TF", False)

    Values must be strings: ``os.environ`` rejects non-string values such as bools.
    The original environment is restored when the decorated call or ``with`` block ends.
    """
    # `import unittest` alone does not guarantee the `unittest.mock` submodule is
    # loaded, so `unittest.mock` could raise AttributeError; import it explicitly.
    from unittest import mock

    return mock.patch.dict(os.environ, kwargs)
|