[deepspeed] fix a bug in a test (#15493)
* [deepspeed] fix a bug in a test * consistency
This commit is contained in:
@@ -25,6 +25,7 @@ from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available
|
|||||||
from transformers.file_utils import WEIGHTS_NAME
|
from transformers.file_utils import WEIGHTS_NAME
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
CaptureLogger,
|
CaptureLogger,
|
||||||
|
CaptureStd,
|
||||||
CaptureStderr,
|
CaptureStderr,
|
||||||
ExtendSysPath,
|
ExtendSysPath,
|
||||||
LoggingLevel,
|
LoggingLevel,
|
||||||
@@ -972,7 +973,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
|
|||||||
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
|
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
|
||||||
with CaptureStderr() as cs:
|
with CaptureStderr() as cs:
|
||||||
execute_subprocess_async(cmd, env=self.get_env())
|
execute_subprocess_async(cmd, env=self.get_env())
|
||||||
assert "Detected DeepSpeed ZeRO-3" in cs.err
|
self.assertIn("Detected DeepSpeed ZeRO-3", cs.err)
|
||||||
|
|
||||||
@parameterized.expand(stages)
|
@parameterized.expand(stages)
|
||||||
def test_load_best_model(self, stage):
|
def test_load_best_model(self, stage):
|
||||||
@@ -1008,14 +1009,14 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
|
|||||||
""".split()
|
""".split()
|
||||||
args.extend(["--source_prefix", "translate English to Romanian: "])
|
args.extend(["--source_prefix", "translate English to Romanian: "])
|
||||||
|
|
||||||
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_zero3.json".split()
|
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_{stage}.json".split()
|
||||||
script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"]
|
script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"]
|
||||||
launcher = get_launcher(distributed=False)
|
launcher = get_launcher(distributed=False)
|
||||||
|
|
||||||
cmd = launcher + script + args + ds_args
|
cmd = launcher + script + args + ds_args
|
||||||
# keep for quick debug
|
# keep for quick debug
|
||||||
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
|
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
|
||||||
with CaptureStderr() as cs:
|
with CaptureStd() as cs:
|
||||||
execute_subprocess_async(cmd, env=self.get_env())
|
execute_subprocess_async(cmd, env=self.get_env())
|
||||||
# enough to test it didn't fail
|
# enough to test it didn't fail
|
||||||
assert "Detected DeepSpeed ZeRO-3" in cs.err
|
self.assertIn("DeepSpeed info", cs.out)
|
||||||
|
|||||||
Reference in New Issue
Block a user