Fix example logs repeating themselves (#16669)

Move declaration of log streams to before tests, so that results won't get compounded on top of each other
This commit is contained in:
Zachary Mueller
2022-04-11 16:25:16 -04:00
committed by GitHub
parent dce33f2150
commit 69233cf03b
5 changed files with 19 additions and 99 deletions

View File

@@ -70,11 +70,12 @@ def get_results(output_dir, split="eval"):
return results
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class ExamplesTests(TestCasePlus):
def test_run_glue(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_glue.py
@@ -98,9 +99,6 @@ class ExamplesTests(TestCasePlus):
@slow
def test_run_clm(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_clm_flax.py
@@ -125,9 +123,6 @@ class ExamplesTests(TestCasePlus):
@slow
def test_run_summarization(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_summarization.py
@@ -158,9 +153,6 @@ class ExamplesTests(TestCasePlus):
@slow
def test_run_mlm(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_mlm.py
@@ -185,9 +177,6 @@ class ExamplesTests(TestCasePlus):
@slow
def test_run_t5_mlm(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_t5_mlm_flax.py
@@ -212,9 +201,6 @@ class ExamplesTests(TestCasePlus):
@slow
def test_run_ner(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
epochs = 7 if get_gpu_count() > 1 else 2
@@ -245,9 +231,6 @@ class ExamplesTests(TestCasePlus):
@slow
def test_run_qa(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_qa.py