From b29e2dcaff114762e65eaea739ba1076fc5d1c84 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 28 Feb 2023 22:24:14 +0100 Subject: [PATCH] Fix flaky test for log level (#21776) * Fix flaky test for log level * Fix other flaky test --- tests/trainer/test_trainer.py | 10 ++++++---- tests/utils/test_logging.py | 1 + 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 1680696813..2ff81e5fe7 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1093,18 +1093,20 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertTrue(np.array_equal(2 * expected + 1, seen[: expected.shape[0]])) self.assertTrue(np.all(seen[expected.shape[0] :] == -100)) - # FIXME: sgugger - @unittest.skip(reason="might be flaky after PR #21700. Skip for now.") def test_log_level(self): # testing only --log_level (--log_level_replica requires multiple gpus and DDP and is tested elsewhere) logger = logging.get_logger() log_info_string = "Running training" - # test with the default log_level - should be warning and thus not log on the main process + # test with the default log_level - should be the same as before and thus we test depending on is_info + is_info = logging.get_verbosity() <= 20 with CaptureLogger(logger) as cl: trainer = get_regression_trainer() trainer.train() - self.assertNotIn(log_info_string, cl.out) + if is_info: + self.assertIn(log_info_string, cl.out) + else: + self.assertNotIn(log_info_string, cl.out) # test with low log_level - lower than info with CaptureLogger(logger) as cl: diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py index 5f0b5fe325..c9bbb82436 100644 --- a/tests/utils/test_logging.py +++ b/tests/utils/test_logging.py @@ -109,6 +109,7 @@ class HfArgumentParserTest(unittest.TestCase): def test_advisory_warnings(self): # testing `logger.warning_advice()` + transformers.utils.logging._reset_library_root_logger() logger = logging.get_logger("transformers.models.bart.tokenization_bart") msg = "Testing 1, 2, 3"