Fix example logs repeating themselves (#16669)
Move declaration of log streams to before tests, so that results won't get compounded on top of each other
This commit is contained in:
@@ -70,11 +70,12 @@ def get_results(output_dir, split="eval"):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
class ExamplesTests(TestCasePlus):
|
|
||||||
def test_run_glue(self):
|
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
|
||||||
|
class ExamplesTests(TestCasePlus):
|
||||||
|
def test_run_glue(self):
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_glue.py
|
run_glue.py
|
||||||
@@ -98,9 +99,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_run_clm(self):
|
def test_run_clm(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_clm_flax.py
|
run_clm_flax.py
|
||||||
@@ -125,9 +123,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_run_summarization(self):
|
def test_run_summarization(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_summarization.py
|
run_summarization.py
|
||||||
@@ -158,9 +153,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_run_mlm(self):
|
def test_run_mlm(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_mlm.py
|
run_mlm.py
|
||||||
@@ -185,9 +177,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_run_t5_mlm(self):
|
def test_run_t5_mlm(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_t5_mlm_flax.py
|
run_t5_mlm_flax.py
|
||||||
@@ -212,9 +201,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_run_ner(self):
|
def test_run_ner(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
|
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
|
||||||
epochs = 7 if get_gpu_count() > 1 else 2
|
epochs = 7 if get_gpu_count() > 1 else 2
|
||||||
|
|
||||||
@@ -245,9 +231,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_run_qa(self):
|
def test_run_qa(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_qa.py
|
run_qa.py
|
||||||
|
|||||||
@@ -86,11 +86,12 @@ def is_cuda_and_apex_available():
|
|||||||
return is_using_cuda and is_apex_available()
|
return is_using_cuda and is_apex_available()
|
||||||
|
|
||||||
|
|
||||||
class ExamplesTestsNoTrainer(TestCasePlus):
|
|
||||||
def test_run_glue_no_trainer(self):
|
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
|
||||||
|
class ExamplesTestsNoTrainer(TestCasePlus):
|
||||||
|
def test_run_glue_no_trainer(self):
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_glue_no_trainer.py
|
run_glue_no_trainer.py
|
||||||
@@ -115,9 +116,6 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
|||||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
|
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
|
||||||
|
|
||||||
def test_run_clm_no_trainer(self):
|
def test_run_clm_no_trainer(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_clm_no_trainer.py
|
run_clm_no_trainer.py
|
||||||
@@ -143,9 +141,6 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
|||||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
|
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
|
||||||
|
|
||||||
def test_run_mlm_no_trainer(self):
|
def test_run_mlm_no_trainer(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_mlm_no_trainer.py
|
run_mlm_no_trainer.py
|
||||||
@@ -164,9 +159,6 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
|||||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
|
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
|
||||||
|
|
||||||
def test_run_ner_no_trainer(self):
|
def test_run_ner_no_trainer(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
|
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
|
||||||
epochs = 7 if get_gpu_count() > 1 else 2
|
epochs = 7 if get_gpu_count() > 1 else 2
|
||||||
|
|
||||||
@@ -193,9 +185,6 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
|||||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
|
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
|
||||||
|
|
||||||
def test_run_squad_no_trainer(self):
|
def test_run_squad_no_trainer(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_qa_no_trainer.py
|
run_qa_no_trainer.py
|
||||||
@@ -220,9 +209,6 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
|||||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
|
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
|
||||||
|
|
||||||
def test_run_swag_no_trainer(self):
|
def test_run_swag_no_trainer(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_swag_no_trainer.py
|
run_swag_no_trainer.py
|
||||||
@@ -244,9 +230,6 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_run_summarization_no_trainer(self):
|
def test_run_summarization_no_trainer(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_summarization_no_trainer.py
|
run_summarization_no_trainer.py
|
||||||
@@ -273,9 +256,6 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_run_translation_no_trainer(self):
|
def test_run_translation_no_trainer(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_translation_no_trainer.py
|
run_translation_no_trainer.py
|
||||||
|
|||||||
@@ -97,11 +97,12 @@ def is_cuda_and_apex_available():
|
|||||||
return is_using_cuda and is_apex_available()
|
return is_using_cuda and is_apex_available()
|
||||||
|
|
||||||
|
|
||||||
class ExamplesTests(TestCasePlus):
|
|
||||||
def test_run_glue(self):
|
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
|
||||||
|
class ExamplesTests(TestCasePlus):
|
||||||
|
def test_run_glue(self):
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_glue.py
|
run_glue.py
|
||||||
@@ -130,9 +131,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
|
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
|
||||||
|
|
||||||
def test_run_clm(self):
|
def test_run_clm(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_clm.py
|
run_clm.py
|
||||||
@@ -187,9 +185,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
self.assertIn('"n_head": 2', cl.out)
|
self.assertIn('"n_head": 2', cl.out)
|
||||||
|
|
||||||
def test_run_mlm(self):
|
def test_run_mlm(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_mlm.py
|
run_mlm.py
|
||||||
@@ -213,9 +208,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
self.assertLess(result["perplexity"], 42)
|
self.assertLess(result["perplexity"], 42)
|
||||||
|
|
||||||
def test_run_ner(self):
|
def test_run_ner(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
|
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
|
||||||
epochs = 7 if get_gpu_count() > 1 else 2
|
epochs = 7 if get_gpu_count() > 1 else 2
|
||||||
|
|
||||||
@@ -247,9 +239,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
self.assertLess(result["eval_loss"], 0.5)
|
self.assertLess(result["eval_loss"], 0.5)
|
||||||
|
|
||||||
def test_run_squad(self):
|
def test_run_squad(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_qa.py
|
run_qa.py
|
||||||
@@ -275,9 +264,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
self.assertGreaterEqual(result["eval_exact"], 30)
|
self.assertGreaterEqual(result["eval_exact"], 30)
|
||||||
|
|
||||||
def test_run_squad_seq2seq(self):
|
def test_run_squad_seq2seq(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_seq2seq_qa.py
|
run_seq2seq_qa.py
|
||||||
@@ -307,9 +293,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
self.assertGreaterEqual(result["eval_exact"], 30)
|
self.assertGreaterEqual(result["eval_exact"], 30)
|
||||||
|
|
||||||
def test_run_swag(self):
|
def test_run_swag(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_swag.py
|
run_swag.py
|
||||||
@@ -333,9 +316,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
self.assertGreaterEqual(result["eval_accuracy"], 0.8)
|
self.assertGreaterEqual(result["eval_accuracy"], 0.8)
|
||||||
|
|
||||||
def test_generation(self):
|
def test_generation(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]
|
testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]
|
||||||
|
|
||||||
if is_cuda_and_apex_available():
|
if is_cuda_and_apex_available():
|
||||||
@@ -351,9 +331,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_run_summarization(self):
|
def test_run_summarization(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_summarization.py
|
run_summarization.py
|
||||||
@@ -382,9 +359,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_run_translation(self):
|
def test_run_translation(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_translation.py
|
run_translation.py
|
||||||
@@ -414,9 +388,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
|
|
||||||
@unittest.skip("This is currently broken.")
|
@unittest.skip("This is currently broken.")
|
||||||
def test_run_image_classification(self):
|
def test_run_image_classification(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_image_classification.py
|
run_image_classification.py
|
||||||
@@ -446,9 +417,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
self.assertGreaterEqual(result["eval_accuracy"], 0.8)
|
self.assertGreaterEqual(result["eval_accuracy"], 0.8)
|
||||||
|
|
||||||
def test_run_speech_recognition_ctc(self):
|
def test_run_speech_recognition_ctc(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_speech_recognition_ctc.py
|
run_speech_recognition_ctc.py
|
||||||
@@ -479,9 +447,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
self.assertLess(result["eval_loss"], result["train_loss"])
|
self.assertLess(result["eval_loss"], result["train_loss"])
|
||||||
|
|
||||||
def test_run_speech_recognition_seq2seq(self):
|
def test_run_speech_recognition_seq2seq(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_speech_recognition_seq2seq.py
|
run_speech_recognition_seq2seq.py
|
||||||
@@ -512,9 +477,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
self.assertLess(result["eval_loss"], result["train_loss"])
|
self.assertLess(result["eval_loss"], result["train_loss"])
|
||||||
|
|
||||||
def test_run_audio_classification(self):
|
def test_run_audio_classification(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_audio_classification.py
|
run_audio_classification.py
|
||||||
@@ -547,9 +509,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
self.assertLess(result["eval_loss"], result["train_loss"])
|
self.assertLess(result["eval_loss"], result["train_loss"])
|
||||||
|
|
||||||
def test_run_wav2vec2_pretraining(self):
|
def test_run_wav2vec2_pretraining(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_wav2vec2_pretraining_no_trainer.py
|
run_wav2vec2_pretraining_no_trainer.py
|
||||||
@@ -577,9 +536,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
|
|
||||||
@unittest.skip("This is currently broken.")
|
@unittest.skip("This is currently broken.")
|
||||||
def test_run_vit_mae_pretraining(self):
|
def test_run_vit_mae_pretraining(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_mae.py
|
run_mae.py
|
||||||
|
|||||||
@@ -40,14 +40,15 @@ def get_results(output_dir):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
|
||||||
@require_torch_tpu
|
@require_torch_tpu
|
||||||
class TorchXLAExamplesTests(TestCasePlus):
|
class TorchXLAExamplesTests(TestCasePlus):
|
||||||
def test_run_glue(self):
|
def test_run_glue(self):
|
||||||
import xla_spawn
|
import xla_spawn
|
||||||
|
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
./examples/pytorch/text-classification/run_glue.py
|
./examples/pytorch/text-classification/run_glue.py
|
||||||
|
|||||||
@@ -18,6 +18,9 @@ from transformers.testing_utils import (
|
|||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
|
||||||
class RagFinetuneExampleTests(TestCasePlus):
|
class RagFinetuneExampleTests(TestCasePlus):
|
||||||
def _create_dummy_data(self, data_dir):
|
def _create_dummy_data(self, data_dir):
|
||||||
@@ -31,9 +34,6 @@ class RagFinetuneExampleTests(TestCasePlus):
|
|||||||
f.write(content)
|
f.write(content)
|
||||||
|
|
||||||
def _run_finetune(self, gpus: int, distributed_retriever: str = "pytorch"):
|
def _run_finetune(self, gpus: int, distributed_retriever: str = "pytorch"):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
logger.addHandler(stream_handler)
|
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
output_dir = os.path.join(tmp_dir, "output")
|
output_dir = os.path.join(tmp_dir, "output")
|
||||||
data_dir = os.path.join(tmp_dir, "data")
|
data_dir = os.path.join(tmp_dir, "data")
|
||||||
|
|||||||
Reference in New Issue
Block a user