Add support for ZeRO-2/3 and ZeRO-offload in fairscale (#10354)
* Ass support for ZeRO-2/3 and ZeRO-offload in fairscale * Quality * Rework from review comments * Add doc * Apply suggestions from code review Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Address review comments Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
@@ -64,12 +64,13 @@ def require_apex(test_case):
|
||||
|
||||
|
||||
class TestTrainerExt(TestCasePlus):
|
||||
def run_seq2seq_quick(self, distributed=False, extra_args_str=None):
|
||||
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
|
||||
def run_seq2seq_quick(self, distributed=False, extra_args_str=None, eval=True, predict_with_generate=True):
|
||||
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str, predict_with_generate)
|
||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
first_step_stats = eval_metrics[0]
|
||||
assert "eval_bleu" in first_step_stats
|
||||
if predict_with_generate:
|
||||
assert "eval_bleu" in first_step_stats
|
||||
|
||||
@require_torch_non_multi_gpu
|
||||
def test_run_seq2seq_no_dist(self):
|
||||
@@ -88,14 +89,28 @@ class TestTrainerExt(TestCasePlus):
|
||||
# test --sharded_ddp w/o --fp16
|
||||
@require_torch_multi_gpu
|
||||
@require_fairscale
|
||||
def test_run_seq2seq_ddp_sharded_ddp(self):
|
||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp")
|
||||
def test_run_seq2seq_sharded_ddp(self):
|
||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")
|
||||
|
||||
# test --sharded_ddp w/ --fp16
|
||||
@require_torch_multi_gpu
|
||||
@require_fairscale
|
||||
def test_run_seq2seq_ddp_sharded_ddp_fp16(self):
|
||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
|
||||
def test_run_seq2seq_sharded_ddp_fp16(self):
|
||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")
|
||||
|
||||
# test --sharded_ddp zero2 w/o --fp16
|
||||
@require_torch_multi_gpu
|
||||
@require_fairscale
|
||||
def test_run_seq2seq_fully_sharded_ddp(self):
|
||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero2", predict_with_generate=False)
|
||||
|
||||
# test --sharded_ddp zero2 w/ --fp16
|
||||
@require_torch_multi_gpu
|
||||
@require_fairscale
|
||||
def test_run_seq2seq_fully_sharded_ddp_fp16(self):
|
||||
self.run_seq2seq_quick(
|
||||
distributed=True, extra_args_str="--sharded_ddp zero2 --fp16", predict_with_generate=False
|
||||
)
|
||||
|
||||
@require_apex
|
||||
def test_run_seq2seq_apex(self):
|
||||
@@ -131,6 +146,7 @@ class TestTrainerExt(TestCasePlus):
|
||||
num_train_epochs: int,
|
||||
distributed: bool = False,
|
||||
extra_args_str: str = None,
|
||||
predict_with_generate: bool = True,
|
||||
):
|
||||
data_dir = self.examples_dir / "test_data/wmt_en_ro"
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
@@ -155,7 +171,6 @@ class TestTrainerExt(TestCasePlus):
|
||||
--learning_rate 3e-3
|
||||
--warmup_steps 8
|
||||
--evaluation_strategy steps
|
||||
--predict_with_generate
|
||||
--logging_steps 0
|
||||
--save_steps {str(eval_steps)}
|
||||
--eval_steps {str(eval_steps)}
|
||||
@@ -165,7 +180,11 @@ class TestTrainerExt(TestCasePlus):
|
||||
--task translation
|
||||
--target_lang ro_RO
|
||||
--source_lang en_XX
|
||||
""".split()
|
||||
"""
|
||||
if predict_with_generate:
|
||||
args += "--predict_with_generate"
|
||||
|
||||
args = args.split()
|
||||
|
||||
if extra_args_str is not None:
|
||||
args.extend(extra_args_str.split())
|
||||
|
||||
Reference in New Issue
Block a user