Fix example logs repeating themselves (#16669)
Move declaration of log streams to before tests, so that results won't get compounded on top of each other
This commit is contained in:
@@ -97,11 +97,12 @@ def is_cuda_and_apex_available():
|
||||
return is_using_cuda and is_apex_available()
|
||||
|
||||
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class ExamplesTests(TestCasePlus):
|
||||
def test_run_glue(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_glue.py
|
||||
@@ -130,9 +131,6 @@ class ExamplesTests(TestCasePlus):
|
||||
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
|
||||
|
||||
def test_run_clm(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_clm.py
|
||||
@@ -187,9 +185,6 @@ class ExamplesTests(TestCasePlus):
|
||||
self.assertIn('"n_head": 2', cl.out)
|
||||
|
||||
def test_run_mlm(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_mlm.py
|
||||
@@ -213,9 +208,6 @@ class ExamplesTests(TestCasePlus):
|
||||
self.assertLess(result["perplexity"], 42)
|
||||
|
||||
def test_run_ner(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
|
||||
epochs = 7 if get_gpu_count() > 1 else 2
|
||||
|
||||
@@ -247,9 +239,6 @@ class ExamplesTests(TestCasePlus):
|
||||
self.assertLess(result["eval_loss"], 0.5)
|
||||
|
||||
def test_run_squad(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_qa.py
|
||||
@@ -275,9 +264,6 @@ class ExamplesTests(TestCasePlus):
|
||||
self.assertGreaterEqual(result["eval_exact"], 30)
|
||||
|
||||
def test_run_squad_seq2seq(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_seq2seq_qa.py
|
||||
@@ -307,9 +293,6 @@ class ExamplesTests(TestCasePlus):
|
||||
self.assertGreaterEqual(result["eval_exact"], 30)
|
||||
|
||||
def test_run_swag(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_swag.py
|
||||
@@ -333,9 +316,6 @@ class ExamplesTests(TestCasePlus):
|
||||
self.assertGreaterEqual(result["eval_accuracy"], 0.8)
|
||||
|
||||
def test_generation(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]
|
||||
|
||||
if is_cuda_and_apex_available():
|
||||
@@ -351,9 +331,6 @@ class ExamplesTests(TestCasePlus):
|
||||
|
||||
@slow
|
||||
def test_run_summarization(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_summarization.py
|
||||
@@ -382,9 +359,6 @@ class ExamplesTests(TestCasePlus):
|
||||
|
||||
@slow
|
||||
def test_run_translation(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_translation.py
|
||||
@@ -414,9 +388,6 @@ class ExamplesTests(TestCasePlus):
|
||||
|
||||
@unittest.skip("This is currently broken.")
|
||||
def test_run_image_classification(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_image_classification.py
|
||||
@@ -446,9 +417,6 @@ class ExamplesTests(TestCasePlus):
|
||||
self.assertGreaterEqual(result["eval_accuracy"], 0.8)
|
||||
|
||||
def test_run_speech_recognition_ctc(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_speech_recognition_ctc.py
|
||||
@@ -479,9 +447,6 @@ class ExamplesTests(TestCasePlus):
|
||||
self.assertLess(result["eval_loss"], result["train_loss"])
|
||||
|
||||
def test_run_speech_recognition_seq2seq(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_speech_recognition_seq2seq.py
|
||||
@@ -512,9 +477,6 @@ class ExamplesTests(TestCasePlus):
|
||||
self.assertLess(result["eval_loss"], result["train_loss"])
|
||||
|
||||
def test_run_audio_classification(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_audio_classification.py
|
||||
@@ -547,9 +509,6 @@ class ExamplesTests(TestCasePlus):
|
||||
self.assertLess(result["eval_loss"], result["train_loss"])
|
||||
|
||||
def test_run_wav2vec2_pretraining(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_wav2vec2_pretraining_no_trainer.py
|
||||
@@ -577,9 +536,6 @@ class ExamplesTests(TestCasePlus):
|
||||
|
||||
@unittest.skip("This is currently broken.")
|
||||
def test_run_vit_mae_pretraining(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_mae.py
|
||||
|
||||
Reference in New Issue
Block a user