Fix model saving with Fairscale FSDP (#10596)

* Hotfix fairscale FSDP

* Evaluation works

* Save on process zero
This commit is contained in:
Sylvain Gugger
2021-03-09 14:42:07 -05:00
committed by GitHub
parent ac17f71159
commit 0d909f6bd8
2 changed files with 19 additions and 15 deletions

View File

@@ -66,7 +66,7 @@ def require_apex(test_case):
class TestTrainerExt(TestCasePlus):
def run_seq2seq_quick(self, distributed=False, extra_args_str=None, eval=True, predict_with_generate=True):
def run_seq2seq_quick(self, distributed=False, extra_args_str=None, predict_with_generate=True):
output_dir = self.run_trainer(
eval_steps=1,
max_len=12,
@@ -83,9 +83,9 @@ class TestTrainerExt(TestCasePlus):
if predict_with_generate:
assert "eval_bleu" in first_step_stats
last_step_stats = eval_metrics[-1]
assert isinstance(last_step_stats["eval_bleu"], float)
assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"
last_step_stats = eval_metrics[-1]
assert isinstance(last_step_stats["eval_bleu"], float)
assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"
@require_torch_non_multi_gpu
def test_run_seq2seq_no_dist(self):
@@ -116,14 +116,12 @@ class TestTrainerExt(TestCasePlus):
# test --sharded_ddp zero_dp_2 w/o --fp16
@require_torch_multi_gpu
@require_fairscale
@unittest.skip("XXX: Fixme: hanging")
def test_run_seq2seq_fully_sharded_ddp(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)
# test --sharded_ddp zero_dp_2 w/ --fp16
@require_torch_multi_gpu
@require_fairscale
@unittest.skip("XXX: Fixme: hanging")
def test_run_seq2seq_fully_sharded_ddp_fp16(self):
self.run_seq2seq_quick(
distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
@@ -206,8 +204,8 @@ class TestTrainerExt(TestCasePlus):
--warmup_steps 8
--evaluation_strategy steps
--logging_steps 0
--save_steps {str(eval_steps)}
--eval_steps {str(eval_steps)}
--save_steps {str(eval_steps)}
--group_by_length
--label_smoothing_factor 0.1
--adafactor