Fix decorator order (#22708)
fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -869,8 +869,8 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
|
|||||||
# 2. most tests should probably be run on both: zero2 and zero3 configs
|
# 2. most tests should probably be run on both: zero2 and zero3 configs
|
||||||
#
|
#
|
||||||
|
|
||||||
@require_torch_multi_gpu
|
|
||||||
@parameterized.expand(params, name_func=parameterized_custom_name_func)
|
@parameterized.expand(params, name_func=parameterized_custom_name_func)
|
||||||
|
@require_torch_multi_gpu
|
||||||
def test_basic_distributed(self, stage, dtype):
|
def test_basic_distributed(self, stage, dtype):
|
||||||
self.run_and_check(stage=stage, dtype=dtype, distributed=True)
|
self.run_and_check(stage=stage, dtype=dtype, distributed=True)
|
||||||
|
|
||||||
@@ -900,8 +900,8 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
|
|||||||
fp32=True,
|
fp32=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_torch_multi_gpu
|
|
||||||
@parameterized.expand(params, name_func=parameterized_custom_name_func)
|
@parameterized.expand(params, name_func=parameterized_custom_name_func)
|
||||||
|
@require_torch_multi_gpu
|
||||||
def test_fp32_distributed(self, stage, dtype):
|
def test_fp32_distributed(self, stage, dtype):
|
||||||
# real model needs too much GPU memory under stage2+fp32, so using tiny random model here -
|
# real model needs too much GPU memory under stage2+fp32, so using tiny random model here -
|
||||||
# therefore no quality checks, just basic completion checks are done
|
# therefore no quality checks, just basic completion checks are done
|
||||||
@@ -941,8 +941,8 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
|
|||||||
|
|
||||||
self.do_checks(output_dir, do_train=do_train, do_eval=do_eval)
|
self.do_checks(output_dir, do_train=do_train, do_eval=do_eval)
|
||||||
|
|
||||||
@require_torch_multi_gpu
|
|
||||||
@parameterized.expand(["bf16", "fp16", "fp32"])
|
@parameterized.expand(["bf16", "fp16", "fp32"])
|
||||||
|
@require_torch_multi_gpu
|
||||||
def test_inference(self, dtype):
|
def test_inference(self, dtype):
|
||||||
if dtype == "bf16" and not is_torch_bf16_gpu_available():
|
if dtype == "bf16" and not is_torch_bf16_gpu_available():
|
||||||
self.skipTest("test requires bfloat16 hardware support")
|
self.skipTest("test requires bfloat16 hardware support")
|
||||||
|
|||||||
Reference in New Issue
Block a user