From a18a17d2b6357321279190963765085a0ef4d466 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 14 Jul 2021 13:04:58 -0700 Subject: [PATCH] [test] split test into 4 sub-tests to avoid timeout (#12710) * split the test into 4 sub-tests to avoid timeout * fix decorator order --- tests/extended/test_trainer_ext.py | 59 +++++++++++------------------- 1 file changed, 21 insertions(+), 38 deletions(-) diff --git a/tests/extended/test_trainer_ext.py b/tests/extended/test_trainer_ext.py index dca3604e13..eb225c16f5 100644 --- a/tests/extended/test_trainer_ext.py +++ b/tests/extended/test_trainer_ext.py @@ -19,6 +19,7 @@ import sys import unittest from unittest.mock import patch +from parameterized import parameterized from transformers.file_utils import is_apex_available from transformers.integrations import is_fairscale_available from transformers.testing_utils import ( @@ -164,48 +165,30 @@ class TestTrainerExt(TestCasePlus): # to reproduce the problem set distributed=False self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex") + @parameterized.expand(["base", "low", "high", "mixed"]) @require_torch_multi_gpu - def test_trainer_log_level_replica(self): - log_info_string = "Running training" + def test_trainer_log_level_replica(self, experiment_id): + # as each sub-test is slow-ish split into multiple sub-tests to avoid CI timeout + experiments = dict( + # test with the default log_level - should be info and thus log info once + base=dict(extra_args_str="", n_matches=1), + # test with low log_level and log_level_replica - should be noisy on all processes + # now the info string should appear twice on 2 processes + low=dict(extra_args_str="--log_level debug --log_level_replica debug", n_matches=2), + # test with high log_level and low log_level_replica + # now the info string should appear once only on the replica + high=dict(extra_args_str="--log_level error --log_level_replica debug", n_matches=1), + # test with high log_level and log_level_replica - should be quiet on all processes + mixed=dict(extra_args_str="--log_level error --log_level_replica error", n_matches=0), + ) + + data = experiments[experiment_id] kwargs = dict(distributed=True, predict_with_generate=False, do_eval=False, do_predict=False) - - # test with the default log_level - should be info and thus log info once + log_info_string = "Running training" with CaptureStderr() as cl: - self.run_seq2seq_quick( - **kwargs, - extra_args_str="", - ) + self.run_seq2seq_quick(**kwargs, extra_args_str=data["extra_args_str"]) n_matches = len(re.findall(log_info_string, cl.err)) - self.assertEqual(n_matches, 1) - - # test with low log_level and log_level_replica - should be noisy on all processes - # now the info string should appear twice on 2 processes - with CaptureStderr() as cl: - self.run_seq2seq_quick( - **kwargs, - extra_args_str="--log_level debug --log_level_replica debug", - ) - n_matches = len(re.findall(log_info_string, cl.err)) - self.assertEqual(n_matches, 2) - - # test with high log_level and low log_level_replica - # now the info string should appear once only on the replica - with CaptureStderr() as cl: - self.run_seq2seq_quick( - **kwargs, - extra_args_str="--log_level error --log_level_replica debug", - ) - n_matches = len(re.findall(log_info_string, cl.err)) - self.assertEqual(n_matches, 1) - - # test with high log_level and log_level_replica - should be quiet on all processes - with CaptureStderr() as cl: - self.run_seq2seq_quick( - **kwargs, - extra_args_str="--log_level error --log_level_replica error", - ) - n_matches = len(re.findall(log_info_string, cl.err)) - self.assertEqual(n_matches, 0) + self.assertEqual(n_matches, data["n_matches"]) @slow def test_run_seq2seq_slow(self):