From 0d909f6bd8ca0bc1ec8f42e089b64b4fffc4d230 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 9 Mar 2021 14:42:07 -0500 Subject: [PATCH] Fairscale FSDP fix model save (#10596) * Hotfix fairscale FSDP * Evaluation works * Save on process zero --- examples/tests/trainer/test_trainer_ext.py | 12 +++++------- src/transformers/trainer.py | 22 ++++++++++++++-------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/examples/tests/trainer/test_trainer_ext.py b/examples/tests/trainer/test_trainer_ext.py index b5c97f5a94..38c714709f 100644 --- a/examples/tests/trainer/test_trainer_ext.py +++ b/examples/tests/trainer/test_trainer_ext.py @@ -66,7 +66,7 @@ def require_apex(test_case): class TestTrainerExt(TestCasePlus): - def run_seq2seq_quick(self, distributed=False, extra_args_str=None, eval=True, predict_with_generate=True): + def run_seq2seq_quick(self, distributed=False, extra_args_str=None, predict_with_generate=True): output_dir = self.run_trainer( eval_steps=1, max_len=12, @@ -83,9 +83,9 @@ class TestTrainerExt(TestCasePlus): if predict_with_generate: assert "eval_bleu" in first_step_stats - last_step_stats = eval_metrics[-1] - assert isinstance(last_step_stats["eval_bleu"], float) - assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`" + last_step_stats = eval_metrics[-1] + assert isinstance(last_step_stats["eval_bleu"], float) + assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`" @require_torch_non_multi_gpu def test_run_seq2seq_no_dist(self): @@ -116,14 +116,12 @@ class TestTrainerExt(TestCasePlus): # test --sharded_ddp zero_dp_2 w/o --fp16 @require_torch_multi_gpu @require_fairscale - @unittest.skip("XXX: Fixme: hanging") def test_run_seq2seq_fully_sharded_ddp(self): self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False) # test --sharded_ddp zero_dp_2 w/ --fp16 @require_torch_multi_gpu @require_fairscale - @unittest.skip("XXX: Fixme: hanging") def test_run_seq2seq_fully_sharded_ddp_fp16(self): self.run_seq2seq_quick( distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False @@ -206,8 +204,8 @@ class TestTrainerExt(TestCasePlus): --warmup_steps 8 --evaluation_strategy steps --logging_steps 0 - --save_steps {str(eval_steps)} --eval_steps {str(eval_steps)} + --save_steps {str(eval_steps)} --group_by_length --label_smoothing_factor 0.1 --adafactor diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index aaf9c1e627..0ecf598697 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1497,11 +1497,14 @@ class Trainer: """ if is_torch_tpu_available(): self._save_tpu(output_dir) - else: + elif ( + ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp + ): + state_dict = self.model.state_dict() if self.is_world_process_zero(): - self._save(output_dir) - if self.args.local_rank != -1: - dist.barrier() + self._save(output_dir, state_dict=state_dict) + elif self.is_world_process_zero(): + self._save(output_dir) def _save_tpu(self, output_dir: Optional[str] = None): output_dir = output_dir if output_dir is not None else self.args.output_dir @@ -1531,7 +1534,7 @@ class Trainer: if self.tokenizer is not None and self.is_world_process_zero(): self.tokenizer.save_pretrained(output_dir) - def _save(self, output_dir: Optional[str] = None): + def _save(self, output_dir: Optional[str] = None, state_dict=None): # If we are executing this function, we are the process zero, so we don't check for that. output_dir = output_dir if output_dir is not None else self.args.output_dir os.makedirs(output_dir, exist_ok=True) @@ -1540,13 +1543,16 @@ class Trainer: # They can then be reloaded using `from_pretrained()` if not isinstance(self.model, PreTrainedModel): if isinstance(unwrap_model(self.model), PreTrainedModel): - unwrap_model(self.model).save_pretrained(output_dir, state_dict=self.model.state_dict()) + if state_dict is None: + state_dict = self.model.state_dict() + unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict) else: logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") - state_dict = self.model.state_dict() + if state_dict is None: + state_dict = self.model.state_dict() torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: - self.model.save_pretrained(output_dir) + self.model.save_pretrained(output_dir, state_dict=state_dict) if self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir)