[deepspeed] fix a bug in a test (#15493)
* [deepspeed] fix a bug in a test * consistency
This commit is contained in:
@@ -25,6 +25,7 @@ from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available
|
||||
from transformers.file_utils import WEIGHTS_NAME
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
CaptureStd,
|
||||
CaptureStderr,
|
||||
ExtendSysPath,
|
||||
LoggingLevel,
|
||||
@@ -972,7 +973,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
|
||||
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
|
||||
with CaptureStderr() as cs:
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
assert "Detected DeepSpeed ZeRO-3" in cs.err
|
||||
self.assertIn("Detected DeepSpeed ZeRO-3", cs.err)
|
||||
|
||||
@parameterized.expand(stages)
|
||||
def test_load_best_model(self, stage):
|
||||
@@ -1008,14 +1009,14 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
|
||||
""".split()
|
||||
args.extend(["--source_prefix", "translate English to Romanian: "])
|
||||
|
||||
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_zero3.json".split()
|
||||
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_{stage}.json".split()
|
||||
script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"]
|
||||
launcher = get_launcher(distributed=False)
|
||||
|
||||
cmd = launcher + script + args + ds_args
|
||||
# keep for quick debug
|
||||
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
|
||||
with CaptureStderr() as cs:
|
||||
with CaptureStd() as cs:
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
# enough to test it didn't fail
|
||||
assert "Detected DeepSpeed ZeRO-3" in cs.err
|
||||
self.assertIn("DeepSpeed info", cs.out)
|
||||
|
||||
Reference in New Issue
Block a user