Fairscale FSDP fix model save (#10596)
* Hotfix fairscale FSDP
* Evaluation works
* Save on process zero
This commit is contained in:
@@ -66,7 +66,7 @@ def require_apex(test_case):
|
||||
|
||||
|
||||
class TestTrainerExt(TestCasePlus):
|
||||
def run_seq2seq_quick(self, distributed=False, extra_args_str=None, eval=True, predict_with_generate=True):
|
||||
def run_seq2seq_quick(self, distributed=False, extra_args_str=None, predict_with_generate=True):
|
||||
output_dir = self.run_trainer(
|
||||
eval_steps=1,
|
||||
max_len=12,
|
||||
@@ -83,9 +83,9 @@ class TestTrainerExt(TestCasePlus):
|
||||
if predict_with_generate:
|
||||
assert "eval_bleu" in first_step_stats
|
||||
|
||||
last_step_stats = eval_metrics[-1]
|
||||
assert isinstance(last_step_stats["eval_bleu"], float)
|
||||
assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"
|
||||
last_step_stats = eval_metrics[-1]
|
||||
assert isinstance(last_step_stats["eval_bleu"], float)
|
||||
assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"
|
||||
|
||||
@require_torch_non_multi_gpu
|
||||
def test_run_seq2seq_no_dist(self):
|
||||
@@ -116,14 +116,12 @@ class TestTrainerExt(TestCasePlus):
|
||||
# test --sharded_ddp zero_dp_2 w/o --fp16
|
||||
@require_torch_multi_gpu
|
||||
@require_fairscale
|
||||
@unittest.skip("XXX: Fixme: hanging")
|
||||
def test_run_seq2seq_fully_sharded_ddp(self):
|
||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)
|
||||
|
||||
# test --sharded_ddp zero_dp_2 w/ --fp16
|
||||
@require_torch_multi_gpu
|
||||
@require_fairscale
|
||||
@unittest.skip("XXX: Fixme: hanging")
|
||||
def test_run_seq2seq_fully_sharded_ddp_fp16(self):
|
||||
self.run_seq2seq_quick(
|
||||
distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
|
||||
@@ -206,8 +204,8 @@ class TestTrainerExt(TestCasePlus):
|
||||
--warmup_steps 8
|
||||
--evaluation_strategy steps
|
||||
--logging_steps 0
|
||||
--save_steps {str(eval_steps)}
|
||||
--eval_steps {str(eval_steps)}
|
||||
--save_steps {str(eval_steps)}
|
||||
--group_by_length
|
||||
--label_smoothing_factor 0.1
|
||||
--adafactor
|
||||
|
||||
Reference in New Issue
Block a user