Use hash to clean the test dirs (#6475)
* Use hash to clean the test dirs * Use hash to clean the test dirs * Use hash to clean the test dirs * fix
This commit is contained in:
@@ -20,7 +20,7 @@ def get_setup_file():
|
|||||||
return args.f
|
return args.f
|
||||||
|
|
||||||
|
|
||||||
def clean_test_dir(path="./tests/fixtures/tests_samples/temp_dir"):
|
def clean_test_dir(path):
|
||||||
shutil.rmtree(path, ignore_errors=True)
|
shutil.rmtree(path, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
@@ -37,7 +37,6 @@ class PabeeTests(unittest.TestCase):
|
|||||||
--task_name mrpc
|
--task_name mrpc
|
||||||
--do_train
|
--do_train
|
||||||
--do_eval
|
--do_eval
|
||||||
--output_dir ./tests/fixtures/tests_samples/temp_dir
|
|
||||||
--per_gpu_train_batch_size=2
|
--per_gpu_train_batch_size=2
|
||||||
--per_gpu_eval_batch_size=1
|
--per_gpu_eval_batch_size=1
|
||||||
--learning_rate=2e-5
|
--learning_rate=2e-5
|
||||||
@@ -46,10 +45,13 @@ class PabeeTests(unittest.TestCase):
|
|||||||
--overwrite_output_dir
|
--overwrite_output_dir
|
||||||
--seed=42
|
--seed=42
|
||||||
--max_seq_length=128
|
--max_seq_length=128
|
||||||
""".split()
|
"""
|
||||||
|
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
|
||||||
|
testargs += "--output_dir " + output_dir
|
||||||
|
testargs = testargs.split()
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
result = run_glue_with_pabee.main()
|
result = run_glue_with_pabee.main()
|
||||||
for value in result.values():
|
for value in result.values():
|
||||||
self.assertGreaterEqual(value, 0.75)
|
self.assertGreaterEqual(value, 0.75)
|
||||||
|
|
||||||
clean_test_dir()
|
clean_test_dir(output_dir)
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ def get_setup_file():
|
|||||||
return args.f
|
return args.f
|
||||||
|
|
||||||
|
|
||||||
def clean_test_dir(path="./tests/fixtures/tests_samples/temp_dir"):
|
def clean_test_dir(path):
|
||||||
shutil.rmtree(path, ignore_errors=True)
|
shutil.rmtree(path, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
@@ -68,7 +68,6 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
--task_name mrpc
|
--task_name mrpc
|
||||||
--do_train
|
--do_train
|
||||||
--do_eval
|
--do_eval
|
||||||
--output_dir ./tests/fixtures/tests_samples/temp_dir
|
|
||||||
--per_device_train_batch_size=2
|
--per_device_train_batch_size=2
|
||||||
--per_device_eval_batch_size=1
|
--per_device_eval_batch_size=1
|
||||||
--learning_rate=1e-4
|
--learning_rate=1e-4
|
||||||
@@ -77,13 +76,16 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
--overwrite_output_dir
|
--overwrite_output_dir
|
||||||
--seed=42
|
--seed=42
|
||||||
--max_seq_length=128
|
--max_seq_length=128
|
||||||
""".split()
|
"""
|
||||||
|
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
|
||||||
|
testargs += "--output_dir " + output_dir
|
||||||
|
testargs = testargs.split()
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
result = run_glue.main()
|
result = run_glue.main()
|
||||||
del result["eval_loss"]
|
del result["eval_loss"]
|
||||||
for value in result.values():
|
for value in result.values():
|
||||||
self.assertGreaterEqual(value, 0.75)
|
self.assertGreaterEqual(value, 0.75)
|
||||||
clean_test_dir()
|
clean_test_dir(output_dir)
|
||||||
|
|
||||||
def test_run_pl_glue(self):
|
def test_run_pl_glue(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
@@ -96,13 +98,15 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
--task mrpc
|
--task mrpc
|
||||||
--do_train
|
--do_train
|
||||||
--do_predict
|
--do_predict
|
||||||
--output_dir ./tests/fixtures/tests_samples/temp_dir
|
|
||||||
--train_batch_size=32
|
--train_batch_size=32
|
||||||
--learning_rate=1e-4
|
--learning_rate=1e-4
|
||||||
--num_train_epochs=1
|
--num_train_epochs=1
|
||||||
--seed=42
|
--seed=42
|
||||||
--max_seq_length=128
|
--max_seq_length=128
|
||||||
""".split()
|
"""
|
||||||
|
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
|
||||||
|
testargs += "--output_dir " + output_dir
|
||||||
|
testargs = testargs.split()
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
testargs += ["--fp16", "--gpus=1"]
|
testargs += ["--fp16", "--gpus=1"]
|
||||||
@@ -119,7 +123,7 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
# for k, v in result.items():
|
# for k, v in result.items():
|
||||||
# self.assertGreaterEqual(v, 0.75, f"({k})")
|
# self.assertGreaterEqual(v, 0.75, f"({k})")
|
||||||
#
|
#
|
||||||
clean_test_dir()
|
clean_test_dir(output_dir)
|
||||||
|
|
||||||
def test_run_language_modeling(self):
|
def test_run_language_modeling(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
@@ -133,17 +137,19 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
--line_by_line
|
--line_by_line
|
||||||
--train_data_file ./tests/fixtures/sample_text.txt
|
--train_data_file ./tests/fixtures/sample_text.txt
|
||||||
--eval_data_file ./tests/fixtures/sample_text.txt
|
--eval_data_file ./tests/fixtures/sample_text.txt
|
||||||
--output_dir ./tests/fixtures/tests_samples/temp_dir
|
|
||||||
--overwrite_output_dir
|
--overwrite_output_dir
|
||||||
--do_train
|
--do_train
|
||||||
--do_eval
|
--do_eval
|
||||||
--num_train_epochs=1
|
--num_train_epochs=1
|
||||||
--no_cuda
|
--no_cuda
|
||||||
""".split()
|
"""
|
||||||
|
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
|
||||||
|
testargs += "--output_dir " + output_dir
|
||||||
|
testargs = testargs.split()
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
result = run_language_modeling.main()
|
result = run_language_modeling.main()
|
||||||
self.assertLess(result["perplexity"], 35)
|
self.assertLess(result["perplexity"], 35)
|
||||||
clean_test_dir()
|
clean_test_dir(output_dir)
|
||||||
|
|
||||||
def test_run_squad(self):
|
def test_run_squad(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
@@ -154,7 +160,6 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
--model_type=distilbert
|
--model_type=distilbert
|
||||||
--model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad
|
--model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad
|
||||||
--data_dir=./tests/fixtures/tests_samples/SQUAD
|
--data_dir=./tests/fixtures/tests_samples/SQUAD
|
||||||
--output_dir=./tests/fixtures/tests_samples/temp_dir
|
|
||||||
--max_steps=10
|
--max_steps=10
|
||||||
--warmup_steps=2
|
--warmup_steps=2
|
||||||
--do_train
|
--do_train
|
||||||
@@ -165,12 +170,15 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
--per_gpu_eval_batch_size=1
|
--per_gpu_eval_batch_size=1
|
||||||
--overwrite_output_dir
|
--overwrite_output_dir
|
||||||
--seed=42
|
--seed=42
|
||||||
""".split()
|
"""
|
||||||
|
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
|
||||||
|
testargs += "--output_dir " + output_dir
|
||||||
|
testargs = testargs.split()
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
result = run_squad.main()
|
result = run_squad.main()
|
||||||
self.assertGreaterEqual(result["f1"], 25)
|
self.assertGreaterEqual(result["f1"], 25)
|
||||||
self.assertGreaterEqual(result["exact"], 21)
|
self.assertGreaterEqual(result["exact"], 21)
|
||||||
clean_test_dir()
|
clean_test_dir(output_dir)
|
||||||
|
|
||||||
def test_generation(self):
|
def test_generation(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
|||||||
Reference in New Issue
Block a user