From 4f5faaf04407d4d55f75ecfe9246b2b952b87dfc Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 3 Feb 2022 08:55:45 -0800 Subject: [PATCH] [deepspeed] fix a bug in a test (#15493) * [deepspeed] fix a bug in a test * consistency --- tests/deepspeed/test_deepspeed.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index 8aaf789b97..5e27b6e698 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -25,6 +25,7 @@ from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available from transformers.file_utils import WEIGHTS_NAME from transformers.testing_utils import ( CaptureLogger, + CaptureStd, CaptureStderr, ExtendSysPath, LoggingLevel, @@ -972,7 +973,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus): # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die with CaptureStderr() as cs: execute_subprocess_async(cmd, env=self.get_env()) - assert "Detected DeepSpeed ZeRO-3" in cs.err + self.assertIn("Detected DeepSpeed ZeRO-3", cs.err) @parameterized.expand(stages) def test_load_best_model(self, stage): @@ -1008,14 +1009,14 @@ class TestDeepSpeedWithLauncher(TestCasePlus): """.split() args.extend(["--source_prefix", "translate English to Romanian: "]) - ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_zero3.json".split() + ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_{stage}.json".split() script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"] launcher = get_launcher(distributed=False) cmd = launcher + script + args + ds_args # keep for quick debug # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die - with CaptureStderr() as cs: + with CaptureStd() as cs: execute_subprocess_async(cmd, env=self.get_env()) # enough to test it didn't fail - assert "Detected DeepSpeed ZeRO-3" in cs.err + self.assertIn("DeepSpeed info", cs.out)