[deepspeed] zero inference (#14253)
* [deepspeed] zero inference * only z3 makes sense for inference * fix and style * docs * rework * fix test * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * responding to suggestions Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -697,11 +697,10 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
|
||||
def test_basic_distributed(self, stage):
|
||||
self.run_and_check(stage=stage, distributed=True)
|
||||
|
||||
@parameterized.expand(stages)
|
||||
def test_do_eval_no_train(self, stage):
|
||||
# we should not fail if train is skipped
|
||||
def test_do_eval_no_train(self):
|
||||
# testing only zero3 since zero2 makes no sense with inference
|
||||
self.run_and_check(
|
||||
stage=stage,
|
||||
stage=ZERO3,
|
||||
eval_steps=1,
|
||||
distributed=False,
|
||||
do_train=False,
|
||||
@@ -755,6 +754,22 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
|
||||
|
||||
self.do_checks(output_dir, do_train=do_train, do_eval=do_eval)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
@parameterized.expand(["fp16", "fp32"])
|
||||
def test_inference(self, dtype):
|
||||
# this is just inference, so no optimizer should be loaded
|
||||
# it only works for z3 (makes no sense with z1-z2)
|
||||
fp16 = True if dtype == "fp16" else False
|
||||
self.run_and_check(
|
||||
stage=ZERO3,
|
||||
model_name=T5_TINY,
|
||||
distributed=True,
|
||||
do_train=False,
|
||||
do_eval=True,
|
||||
quality_checks=False,
|
||||
fp16=fp16,
|
||||
)
|
||||
|
||||
def do_checks(self, output_dir, do_train=True, do_eval=True, quality_checks=True):
|
||||
|
||||
if do_train:
|
||||
|
||||
Reference in New Issue
Block a user