Add support for ZeRO-2/3 and ZeRO-offload in fairscale (#10354)

* Add support for ZeRO-2/3 and ZeRO-offload in fairscale

* Quality

* Rework from review comments

* Add doc

* Apply suggestions from code review

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Address review comments

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
Sylvain Gugger
2021-02-25 11:07:53 -05:00
committed by GitHub
parent 88cc26dcd1
commit 9d14be5c20
5 changed files with 193 additions and 46 deletions

View File

@@ -64,12 +64,13 @@ def require_apex(test_case):
class TestTrainerExt(TestCasePlus):
def run_seq2seq_quick(self, distributed=False, extra_args_str=None):
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
def run_seq2seq_quick(self, distributed=False, extra_args_str=None, eval=True, predict_with_generate=True):
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str, predict_with_generate)
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
first_step_stats = eval_metrics[0]
assert "eval_bleu" in first_step_stats
if predict_with_generate:
assert "eval_bleu" in first_step_stats
@require_torch_non_multi_gpu
def test_run_seq2seq_no_dist(self):
@@ -88,14 +89,28 @@ class TestTrainerExt(TestCasePlus):
# test --sharded_ddp w/o --fp16
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_ddp_sharded_ddp(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp")
def test_run_seq2seq_sharded_ddp(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")
# test --sharded_ddp w/ --fp16
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_ddp_sharded_ddp_fp16(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
def test_run_seq2seq_sharded_ddp_fp16(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")
# test --sharded_ddp zero2 w/o --fp16
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_fully_sharded_ddp(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero2", predict_with_generate=False)
# test --sharded_ddp zero2 w/ --fp16
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_fully_sharded_ddp_fp16(self):
self.run_seq2seq_quick(
distributed=True, extra_args_str="--sharded_ddp zero2 --fp16", predict_with_generate=False
)
@require_apex
def test_run_seq2seq_apex(self):
@@ -131,6 +146,7 @@ class TestTrainerExt(TestCasePlus):
num_train_epochs: int,
distributed: bool = False,
extra_args_str: str = None,
predict_with_generate: bool = True,
):
data_dir = self.examples_dir / "test_data/wmt_en_ro"
output_dir = self.get_auto_remove_tmp_dir()
@@ -155,7 +171,6 @@ class TestTrainerExt(TestCasePlus):
--learning_rate 3e-3
--warmup_steps 8
--evaluation_strategy steps
--predict_with_generate
--logging_steps 0
--save_steps {str(eval_steps)}
--eval_steps {str(eval_steps)}
@@ -165,7 +180,11 @@ class TestTrainerExt(TestCasePlus):
--task translation
--target_lang ro_RO
--source_lang en_XX
""".split()
"""
if predict_with_generate:
args += "--predict_with_generate"
args = args.split()
if extra_args_str is not None:
args.extend(extra_args_str.split())