[testing] a new TestCasePlus subclass + get_auto_remove_tmp_dir() (#6494)
* [testing] switch to a new TestCasePlus + get_auto_remove_tmp_dir() for auto-removal of tmp dirs * respect after=True for tempfile, simplify code * comments * comment fix * put `before` last in args, so can make debug even faster
This commit is contained in:
@@ -1,11 +1,10 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import run_glue_with_pabee
|
import run_glue_with_pabee
|
||||||
|
from transformers.testing_utils import TestCasePlus
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
@@ -20,20 +19,19 @@ def get_setup_file():
|
|||||||
return args.f
|
return args.f
|
||||||
|
|
||||||
|
|
||||||
def clean_test_dir(path):
|
class PabeeTests(TestCasePlus):
|
||||||
shutil.rmtree(path, ignore_errors=True)
|
|
||||||
|
|
||||||
|
|
||||||
class PabeeTests(unittest.TestCase):
|
|
||||||
def test_run_glue(self):
|
def test_run_glue(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
testargs = """
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
testargs = f"""
|
||||||
run_glue_with_pabee.py
|
run_glue_with_pabee.py
|
||||||
--model_type albert
|
--model_type albert
|
||||||
--model_name_or_path albert-base-v2
|
--model_name_or_path albert-base-v2
|
||||||
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
||||||
|
--output_dir {tmp_dir}
|
||||||
|
--overwrite_output_dir
|
||||||
--task_name mrpc
|
--task_name mrpc
|
||||||
--do_train
|
--do_train
|
||||||
--do_eval
|
--do_eval
|
||||||
@@ -42,16 +40,11 @@ class PabeeTests(unittest.TestCase):
|
|||||||
--learning_rate=2e-5
|
--learning_rate=2e-5
|
||||||
--max_steps=50
|
--max_steps=50
|
||||||
--warmup_steps=2
|
--warmup_steps=2
|
||||||
--overwrite_output_dir
|
|
||||||
--seed=42
|
--seed=42
|
||||||
--max_seq_length=128
|
--max_seq_length=128
|
||||||
"""
|
""".split()
|
||||||
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
|
|
||||||
testargs += "--output_dir " + output_dir
|
|
||||||
testargs = testargs.split()
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
result = run_glue_with_pabee.main()
|
result = run_glue_with_pabee.main()
|
||||||
for value in result.values():
|
for value in result.values():
|
||||||
self.assertGreaterEqual(value, 0.75)
|
self.assertGreaterEqual(value, 0.75)
|
||||||
|
|
||||||
clean_test_dir(output_dir)
|
|
||||||
|
|||||||
@@ -17,13 +17,13 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from transformers.testing_utils import TestCasePlus
|
||||||
|
|
||||||
|
|
||||||
SRC_DIRS = [
|
SRC_DIRS = [
|
||||||
os.path.join(os.path.dirname(__file__), dirname)
|
os.path.join(os.path.dirname(__file__), dirname)
|
||||||
@@ -52,19 +52,18 @@ def get_setup_file():
|
|||||||
return args.f
|
return args.f
|
||||||
|
|
||||||
|
|
||||||
def clean_test_dir(path):
|
class ExamplesTests(TestCasePlus):
|
||||||
shutil.rmtree(path, ignore_errors=True)
|
|
||||||
|
|
||||||
|
|
||||||
class ExamplesTests(unittest.TestCase):
|
|
||||||
def test_run_glue(self):
|
def test_run_glue(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
testargs = """
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
testargs = f"""
|
||||||
run_glue.py
|
run_glue.py
|
||||||
--model_name_or_path distilbert-base-uncased
|
--model_name_or_path distilbert-base-uncased
|
||||||
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
||||||
|
--output_dir {tmp_dir}
|
||||||
|
--overwrite_output_dir
|
||||||
--task_name mrpc
|
--task_name mrpc
|
||||||
--do_train
|
--do_train
|
||||||
--do_eval
|
--do_eval
|
||||||
@@ -73,28 +72,26 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
--learning_rate=1e-4
|
--learning_rate=1e-4
|
||||||
--max_steps=10
|
--max_steps=10
|
||||||
--warmup_steps=2
|
--warmup_steps=2
|
||||||
--overwrite_output_dir
|
|
||||||
--seed=42
|
--seed=42
|
||||||
--max_seq_length=128
|
--max_seq_length=128
|
||||||
"""
|
""".split()
|
||||||
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
|
|
||||||
testargs += "--output_dir " + output_dir
|
|
||||||
testargs = testargs.split()
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
result = run_glue.main()
|
result = run_glue.main()
|
||||||
del result["eval_loss"]
|
del result["eval_loss"]
|
||||||
for value in result.values():
|
for value in result.values():
|
||||||
self.assertGreaterEqual(value, 0.75)
|
self.assertGreaterEqual(value, 0.75)
|
||||||
clean_test_dir(output_dir)
|
|
||||||
|
|
||||||
def test_run_pl_glue(self):
|
def test_run_pl_glue(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
testargs = """
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
testargs = f"""
|
||||||
run_pl_glue.py
|
run_pl_glue.py
|
||||||
--model_name_or_path bert-base-cased
|
--model_name_or_path bert-base-cased
|
||||||
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
||||||
|
--output_dir {tmp_dir}
|
||||||
--task mrpc
|
--task mrpc
|
||||||
--do_train
|
--do_train
|
||||||
--do_predict
|
--do_predict
|
||||||
@@ -103,11 +100,7 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
--num_train_epochs=1
|
--num_train_epochs=1
|
||||||
--seed=42
|
--seed=42
|
||||||
--max_seq_length=128
|
--max_seq_length=128
|
||||||
"""
|
""".split()
|
||||||
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
|
|
||||||
testargs += "--output_dir " + output_dir
|
|
||||||
testargs = testargs.split()
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
testargs += ["--fp16", "--gpus=1"]
|
testargs += ["--fp16", "--gpus=1"]
|
||||||
|
|
||||||
@@ -123,13 +116,13 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
# for k, v in result.items():
|
# for k, v in result.items():
|
||||||
# self.assertGreaterEqual(v, 0.75, f"({k})")
|
# self.assertGreaterEqual(v, 0.75, f"({k})")
|
||||||
#
|
#
|
||||||
clean_test_dir(output_dir)
|
|
||||||
|
|
||||||
def test_run_language_modeling(self):
|
def test_run_language_modeling(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
testargs = """
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
testargs = f"""
|
||||||
run_language_modeling.py
|
run_language_modeling.py
|
||||||
--model_name_or_path distilroberta-base
|
--model_name_or_path distilroberta-base
|
||||||
--model_type roberta
|
--model_type roberta
|
||||||
@@ -137,29 +130,30 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
--line_by_line
|
--line_by_line
|
||||||
--train_data_file ./tests/fixtures/sample_text.txt
|
--train_data_file ./tests/fixtures/sample_text.txt
|
||||||
--eval_data_file ./tests/fixtures/sample_text.txt
|
--eval_data_file ./tests/fixtures/sample_text.txt
|
||||||
|
--output_dir {tmp_dir}
|
||||||
--overwrite_output_dir
|
--overwrite_output_dir
|
||||||
--do_train
|
--do_train
|
||||||
--do_eval
|
--do_eval
|
||||||
--num_train_epochs=1
|
--num_train_epochs=1
|
||||||
--no_cuda
|
--no_cuda
|
||||||
"""
|
""".split()
|
||||||
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
|
|
||||||
testargs += "--output_dir " + output_dir
|
|
||||||
testargs = testargs.split()
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
result = run_language_modeling.main()
|
result = run_language_modeling.main()
|
||||||
self.assertLess(result["perplexity"], 35)
|
self.assertLess(result["perplexity"], 35)
|
||||||
clean_test_dir(output_dir)
|
|
||||||
|
|
||||||
def test_run_squad(self):
|
def test_run_squad(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
testargs = """
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
testargs = f"""
|
||||||
run_squad.py
|
run_squad.py
|
||||||
--model_type=distilbert
|
--model_type=distilbert
|
||||||
--model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad
|
--model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad
|
||||||
--data_dir=./tests/fixtures/tests_samples/SQUAD
|
--data_dir=./tests/fixtures/tests_samples/SQUAD
|
||||||
|
--output_dir {tmp_dir}
|
||||||
|
--overwrite_output_dir
|
||||||
--max_steps=10
|
--max_steps=10
|
||||||
--warmup_steps=2
|
--warmup_steps=2
|
||||||
--do_train
|
--do_train
|
||||||
@@ -168,17 +162,13 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
--learning_rate=2e-4
|
--learning_rate=2e-4
|
||||||
--per_gpu_train_batch_size=2
|
--per_gpu_train_batch_size=2
|
||||||
--per_gpu_eval_batch_size=1
|
--per_gpu_eval_batch_size=1
|
||||||
--overwrite_output_dir
|
|
||||||
--seed=42
|
--seed=42
|
||||||
"""
|
""".split()
|
||||||
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
|
|
||||||
testargs += "--output_dir " + output_dir
|
|
||||||
testargs = testargs.split()
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
result = run_squad.main()
|
result = run_squad.main()
|
||||||
self.assertGreaterEqual(result["f1"], 25)
|
self.assertGreaterEqual(result["f1"], 25)
|
||||||
self.assertGreaterEqual(result["exact"], 21)
|
self.assertGreaterEqual(result["exact"], 21)
|
||||||
clean_test_dir(output_dir)
|
|
||||||
|
|
||||||
def test_generation(self):
|
def test_generation(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from distutils.util import strtobool
|
from distutils.util import strtobool
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from .file_utils import _tf_available, _torch_available, _torch_tpu_available
|
from .file_utils import _tf_available, _torch_available, _torch_tpu_available
|
||||||
|
|
||||||
@@ -255,3 +258,92 @@ class CaptureStderr(CaptureStd):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(out=False)
|
super().__init__(out=False)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCasePlus(unittest.TestCase):
|
||||||
|
"""This class extends `unittest.TestCase` with additional features.
|
||||||
|
|
||||||
|
Feature 1: Flexible auto-removable temp dirs which are guaranteed to get
|
||||||
|
removed at the end of test.
|
||||||
|
|
||||||
|
In all the following scenarios the temp dir will be auto-removed at the end
|
||||||
|
of test, unless `after=False`.
|
||||||
|
|
||||||
|
# 1. create a unique temp dir, `tmp_dir` will contain the path to the created temp dir
|
||||||
|
def test_whatever(self):
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
|
||||||
|
# 2. create a temp dir of my choice and delete it at the end - useful for debug when you want to
|
||||||
|
# monitor a specific directory
|
||||||
|
def test_whatever(self):
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test")
|
||||||
|
|
||||||
|
# 3. create a temp dir of my choice and do not delete it at the end - useful for when you want
|
||||||
|
# to look at the temp results
|
||||||
|
def test_whatever(self):
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test", after=False)
|
||||||
|
|
||||||
|
# 4. create a temp dir of my choice and ensure to delete it right away - useful for when you
|
||||||
|
# disabled deletion in the previous test run and want to make sure the that tmp dir is empty
|
||||||
|
# before the new test is run
|
||||||
|
def test_whatever(self):
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test", before=True)
|
||||||
|
|
||||||
|
Note 1: In order to run the equivalent of `rm -r` safely, only subdirs of the
|
||||||
|
project repository checkout are allowed if an explicit `tmp_dir` is used, so
|
||||||
|
that by mistake no `/tmp` or similar important part of the filesystem will
|
||||||
|
get nuked. i.e. please always pass paths that start with `./`
|
||||||
|
|
||||||
|
Note 2: Each test can register multiple temp dirs and they all will get
|
||||||
|
auto-removed, unless requested otherwise.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.teardown_tmp_dirs = []
|
||||||
|
|
||||||
|
def get_auto_remove_tmp_dir(self, tmp_dir=None, after=True, before=False):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
tmp_dir (:obj:`string`, `optional`, defaults to :obj:`None`):
|
||||||
|
use this path, if None a unique path will be assigned
|
||||||
|
before (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
if `True` and tmp dir already exists make sure to empty it right away
|
||||||
|
after (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
delete the tmp dir at the end of the test
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tmp_dir(:obj:`string`):
|
||||||
|
either the same value as passed via `tmp_dir` or the path to the auto-created tmp dir
|
||||||
|
"""
|
||||||
|
if tmp_dir is not None:
|
||||||
|
# using provided path
|
||||||
|
path = Path(tmp_dir).resolve()
|
||||||
|
|
||||||
|
# to avoid nuking parts of the filesystem, only relative paths are allowed
|
||||||
|
if not tmp_dir.startswith("./"):
|
||||||
|
raise ValueError(
|
||||||
|
f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ensure the dir is empty to start with
|
||||||
|
if before is True and path.exists():
|
||||||
|
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# using unique tmp dir (always empty, regardless of `before`)
|
||||||
|
tmp_dir = tempfile.mkdtemp()
|
||||||
|
|
||||||
|
if after is True:
|
||||||
|
# register for deletion
|
||||||
|
self.teardown_tmp_dirs.append(tmp_dir)
|
||||||
|
|
||||||
|
return tmp_dir
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
# remove registered temp dirs
|
||||||
|
for path in self.teardown_tmp_dirs:
|
||||||
|
shutil.rmtree(path, ignore_errors=True)
|
||||||
|
self.teardown_tmp_dirs = []
|
||||||
|
|||||||
Reference in New Issue
Block a user