[test] split test into 4 sub-tests to avoid timeout (#12710)
* split the test into 4 sub-tests to avoid timeout * fix decorator order
This commit is contained in:
@@ -19,6 +19,7 @@ import sys
|
|||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from parameterized import parameterized
|
||||||
from transformers.file_utils import is_apex_available
|
from transformers.file_utils import is_apex_available
|
||||||
from transformers.integrations import is_fairscale_available
|
from transformers.integrations import is_fairscale_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
@@ -164,48 +165,30 @@ class TestTrainerExt(TestCasePlus):
|
|||||||
# to reproduce the problem set distributed=False
|
# to reproduce the problem set distributed=False
|
||||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")
|
self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")
|
||||||
|
|
||||||
|
@parameterized.expand(["base", "low", "high", "mixed"])
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_gpu
|
||||||
def test_trainer_log_level_replica(self):
|
def test_trainer_log_level_replica(self, experiment_id):
|
||||||
log_info_string = "Running training"
|
# as each sub-test is slow-ish split into multiple sub-tests to avoid CI timeout
|
||||||
kwargs = dict(distributed=True, predict_with_generate=False, do_eval=False, do_predict=False)
|
experiments = dict(
|
||||||
|
|
||||||
# test with the default log_level - should be info and thus log info once
|
# test with the default log_level - should be info and thus log info once
|
||||||
with CaptureStderr() as cl:
|
base=dict(extra_args_str="", n_matches=1),
|
||||||
self.run_seq2seq_quick(
|
|
||||||
**kwargs,
|
|
||||||
extra_args_str="",
|
|
||||||
)
|
|
||||||
n_matches = len(re.findall(log_info_string, cl.err))
|
|
||||||
self.assertEqual(n_matches, 1)
|
|
||||||
|
|
||||||
# test with low log_level and log_level_replica - should be noisy on all processes
|
# test with low log_level and log_level_replica - should be noisy on all processes
|
||||||
# now the info string should appear twice on 2 processes
|
# now the info string should appear twice on 2 processes
|
||||||
with CaptureStderr() as cl:
|
low=dict(extra_args_str="--log_level debug --log_level_replica debug", n_matches=2),
|
||||||
self.run_seq2seq_quick(
|
|
||||||
**kwargs,
|
|
||||||
extra_args_str="--log_level debug --log_level_replica debug",
|
|
||||||
)
|
|
||||||
n_matches = len(re.findall(log_info_string, cl.err))
|
|
||||||
self.assertEqual(n_matches, 2)
|
|
||||||
|
|
||||||
# test with high log_level and low log_level_replica
|
# test with high log_level and low log_level_replica
|
||||||
# now the info string should appear once only on the replica
|
# now the info string should appear once only on the replica
|
||||||
with CaptureStderr() as cl:
|
high=dict(extra_args_str="--log_level error --log_level_replica debug", n_matches=1),
|
||||||
self.run_seq2seq_quick(
|
|
||||||
**kwargs,
|
|
||||||
extra_args_str="--log_level error --log_level_replica debug",
|
|
||||||
)
|
|
||||||
n_matches = len(re.findall(log_info_string, cl.err))
|
|
||||||
self.assertEqual(n_matches, 1)
|
|
||||||
|
|
||||||
# test with high log_level and log_level_replica - should be quiet on all processes
|
# test with high log_level and log_level_replica - should be quiet on all processes
|
||||||
with CaptureStderr() as cl:
|
mixed=dict(extra_args_str="--log_level error --log_level_replica error", n_matches=0),
|
||||||
self.run_seq2seq_quick(
|
|
||||||
**kwargs,
|
|
||||||
extra_args_str="--log_level error --log_level_replica error",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
data = experiments[experiment_id]
|
||||||
|
kwargs = dict(distributed=True, predict_with_generate=False, do_eval=False, do_predict=False)
|
||||||
|
log_info_string = "Running training"
|
||||||
|
with CaptureStderr() as cl:
|
||||||
|
self.run_seq2seq_quick(**kwargs, extra_args_str=data["extra_args_str"])
|
||||||
n_matches = len(re.findall(log_info_string, cl.err))
|
n_matches = len(re.findall(log_info_string, cl.err))
|
||||||
self.assertEqual(n_matches, 0)
|
self.assertEqual(n_matches, data["n_matches"])
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_run_seq2seq_slow(self):
|
def test_run_seq2seq_slow(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user