From 9dbe4094f2ac2b27d9aa47a6326e8a3a3198f21c Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 17 Aug 2020 05:12:19 -0700 Subject: [PATCH] [testing] a new TestCasePlus subclass + get_auto_remove_tmp_dir() (#6494) * [testing] switch to a new TestCasePlus + get_auto_remove_tmp_dir() for auto-removal of tmp dirs * respect after=True for tempfile, simplify code * comments * comment fix * put `before` last in args, so can make debug even faster --- .../test_run_glue_with_pabee.py | 23 ++--- examples/test_examples.py | 58 +++++------- src/transformers/testing_utils.py | 92 +++++++++++++++++++ 3 files changed, 124 insertions(+), 49 deletions(-) diff --git a/examples/bert-loses-patience/test_run_glue_with_pabee.py b/examples/bert-loses-patience/test_run_glue_with_pabee.py index e626d220c6..22c6f4de06 100644 --- a/examples/bert-loses-patience/test_run_glue_with_pabee.py +++ b/examples/bert-loses-patience/test_run_glue_with_pabee.py @@ -1,11 +1,10 @@ import argparse import logging -import shutil import sys -import unittest from unittest.mock import patch import run_glue_with_pabee +from transformers.testing_utils import TestCasePlus logging.basicConfig(level=logging.DEBUG) @@ -20,20 +19,19 @@ def get_setup_file(): return args.f -def clean_test_dir(path): - shutil.rmtree(path, ignore_errors=True) - - -class PabeeTests(unittest.TestCase): +class PabeeTests(TestCasePlus): def test_run_glue(self): stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) - testargs = """ + tmp_dir = self.get_auto_remove_tmp_dir() + testargs = f""" run_glue_with_pabee.py --model_type albert --model_name_or_path albert-base-v2 --data_dir ./tests/fixtures/tests_samples/MRPC/ + --output_dir {tmp_dir} + --overwrite_output_dir --task_name mrpc --do_train --do_eval @@ -42,16 +40,11 @@ class PabeeTests(unittest.TestCase): --learning_rate=2e-5 --max_steps=50 --warmup_steps=2 - --overwrite_output_dir --seed=42 --max_seq_length=128 - """ - output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs)) - testargs += "--output_dir " + output_dir - testargs = testargs.split() + """.split() + with patch.object(sys, "argv", testargs): result = run_glue_with_pabee.main() for value in result.values(): self.assertGreaterEqual(value, 0.75) - - clean_test_dir(output_dir) diff --git a/examples/test_examples.py b/examples/test_examples.py index ea4117d6fd..90debab95e 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -17,13 +17,13 @@ import argparse import logging import os -import shutil import sys -import unittest from unittest.mock import patch import torch +from transformers.testing_utils import TestCasePlus + SRC_DIRS = [ os.path.join(os.path.dirname(__file__), dirname) @@ -52,19 +52,18 @@ def get_setup_file(): return args.f -def clean_test_dir(path): - shutil.rmtree(path, ignore_errors=True) - - -class ExamplesTests(unittest.TestCase): +class ExamplesTests(TestCasePlus): def test_run_glue(self): stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) - testargs = """ + tmp_dir = self.get_auto_remove_tmp_dir() + testargs = f""" run_glue.py --model_name_or_path distilbert-base-uncased --data_dir ./tests/fixtures/tests_samples/MRPC/ + --output_dir {tmp_dir} + --overwrite_output_dir --task_name mrpc --do_train --do_eval @@ -73,28 +72,26 @@ class ExamplesTests(unittest.TestCase): --learning_rate=1e-4 --max_steps=10 --warmup_steps=2 - --overwrite_output_dir --seed=42 --max_seq_length=128 - """ - output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs)) - testargs += "--output_dir " + output_dir - testargs = testargs.split() + """.split() + with patch.object(sys, "argv", testargs): result = run_glue.main() del result["eval_loss"] for value in result.values(): self.assertGreaterEqual(value, 0.75) - clean_test_dir(output_dir) def test_run_pl_glue(self): stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) - testargs = """ + tmp_dir = self.get_auto_remove_tmp_dir() + testargs = f""" run_pl_glue.py --model_name_or_path bert-base-cased --data_dir ./tests/fixtures/tests_samples/MRPC/ + --output_dir {tmp_dir} --task mrpc --do_train --do_predict @@ -103,11 +100,7 @@ class ExamplesTests(unittest.TestCase): --num_train_epochs=1 --seed=42 --max_seq_length=128 - """ - output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs)) - testargs += "--output_dir " + output_dir - testargs = testargs.split() - + """.split() if torch.cuda.is_available(): testargs += ["--fp16", "--gpus=1"] @@ -123,13 +116,13 @@ class ExamplesTests(unittest.TestCase): # for k, v in result.items(): # self.assertGreaterEqual(v, 0.75, f"({k})") # - clean_test_dir(output_dir) def test_run_language_modeling(self): stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) - testargs = """ + tmp_dir = self.get_auto_remove_tmp_dir() + testargs = f""" run_language_modeling.py --model_name_or_path distilroberta-base --model_type roberta @@ -137,29 +130,30 @@ class ExamplesTests(unittest.TestCase): --line_by_line --train_data_file ./tests/fixtures/sample_text.txt --eval_data_file ./tests/fixtures/sample_text.txt + --output_dir {tmp_dir} --overwrite_output_dir --do_train --do_eval --num_train_epochs=1 --no_cuda - """ - output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs)) - testargs += "--output_dir " + output_dir - testargs = testargs.split() + """.split() + with patch.object(sys, "argv", testargs): result = run_language_modeling.main() self.assertLess(result["perplexity"], 35) - clean_test_dir(output_dir) def test_run_squad(self): stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) - testargs = """ + tmp_dir = self.get_auto_remove_tmp_dir() + testargs = f""" run_squad.py --model_type=distilbert --model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad --data_dir=./tests/fixtures/tests_samples/SQUAD + --output_dir {tmp_dir} + --overwrite_output_dir --max_steps=10 --warmup_steps=2 --do_train @@ -168,17 +162,13 @@ class ExamplesTests(unittest.TestCase): --learning_rate=2e-4 --per_gpu_train_batch_size=2 --per_gpu_eval_batch_size=1 - --overwrite_output_dir --seed=42 - """ - output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs)) - testargs += "--output_dir " + output_dir - testargs = testargs.split() + """.split() + with patch.object(sys, "argv", testargs): result = run_squad.main() self.assertGreaterEqual(result["f1"], 25) self.assertGreaterEqual(result["exact"], 21) - clean_test_dir(output_dir) def test_generation(self): stream_handler = logging.StreamHandler(sys.stdout) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 71c205fbb0..a3339a0fce 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1,9 +1,12 @@ import os import re +import shutil import sys +import tempfile import unittest from distutils.util import strtobool from io import StringIO +from pathlib import Path from .file_utils import _tf_available, _torch_available, _torch_tpu_available @@ -255,3 +258,92 @@ class CaptureStderr(CaptureStd): def __init__(self): super().__init__(out=False) + + +class TestCasePlus(unittest.TestCase): + """This class extends `unittest.TestCase` with additional features. + + Feature 1: Flexible auto-removable temp dirs which are guaranteed to get + removed at the end of test. + + In all the following scenarios the temp dir will be auto-removed at the end + of test, unless `after=False`. + + # 1. create a unique temp dir, `tmp_dir` will contain the path to the created temp dir + def test_whatever(self): + tmp_dir = self.get_auto_remove_tmp_dir() + + # 2. create a temp dir of my choice and delete it at the end - useful for debug when you want to + # monitor a specific directory + def test_whatever(self): + tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test") + + # 3. create a temp dir of my choice and do not delete it at the end - useful for when you want + # to look at the temp results + def test_whatever(self): + tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test", after=False) + + # 4. create a temp dir of my choice and ensure to delete it right away - useful for when you + # disabled deletion in the previous test run and want to make sure the that tmp dir is empty + # before the new test is run + def test_whatever(self): + tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test", before=True) + + Note 1: In order to run the equivalent of `rm -r` safely, only subdirs of the + project repository checkout are allowed if an explicit `tmp_dir` is used, so + that by mistake no `/tmp` or similar important part of the filesystem will + get nuked. i.e. please always pass paths that start with `./` + + Note 2: Each test can register multiple temp dirs and they all will get + auto-removed, unless requested otherwise. + + """ + + def setUp(self): + self.teardown_tmp_dirs = [] + + def get_auto_remove_tmp_dir(self, tmp_dir=None, after=True, before=False): + """ + Args: + tmp_dir (:obj:`string`, `optional`, defaults to :obj:`None`): + use this path, if None a unique path will be assigned + before (:obj:`bool`, `optional`, defaults to :obj:`False`): + if `True` and tmp dir already exists make sure to empty it right away + after (:obj:`bool`, `optional`, defaults to :obj:`True`): + delete the tmp dir at the end of the test + + Returns: + tmp_dir(:obj:`string`): + either the same value as passed via `tmp_dir` or the path to the auto-created tmp dir + """ + if tmp_dir is not None: + # using provided path + path = Path(tmp_dir).resolve() + + # to avoid nuking parts of the filesystem, only relative paths are allowed + if not tmp_dir.startswith("./"): + raise ValueError( + f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`" + ) + + # ensure the dir is empty to start with + if before is True and path.exists(): + shutil.rmtree(tmp_dir, ignore_errors=True) + + path.mkdir(parents=True, exist_ok=True) + + else: + # using unique tmp dir (always empty, regardless of `before`) + tmp_dir = tempfile.mkdtemp() + + if after is True: + # register for deletion + self.teardown_tmp_dirs.append(tmp_dir) + + return tmp_dir + + def tearDown(self): + # remove registered temp dirs + for path in self.teardown_tmp_dirs: + shutil.rmtree(path, ignore_errors=True) + self.teardown_tmp_dirs = []