Fairscale FSDP fix model save (#10596)
* Hotfix fairscale FSDP * Evaluation works * Save on process zero
This commit is contained in:
@@ -66,7 +66,7 @@ def require_apex(test_case):
|
|||||||
|
|
||||||
|
|
||||||
class TestTrainerExt(TestCasePlus):
|
class TestTrainerExt(TestCasePlus):
|
||||||
def run_seq2seq_quick(self, distributed=False, extra_args_str=None, eval=True, predict_with_generate=True):
|
def run_seq2seq_quick(self, distributed=False, extra_args_str=None, predict_with_generate=True):
|
||||||
output_dir = self.run_trainer(
|
output_dir = self.run_trainer(
|
||||||
eval_steps=1,
|
eval_steps=1,
|
||||||
max_len=12,
|
max_len=12,
|
||||||
@@ -116,14 +116,12 @@ class TestTrainerExt(TestCasePlus):
|
|||||||
# test --sharded_ddp zero_dp_2 w/o --fp16
|
# test --sharded_ddp zero_dp_2 w/o --fp16
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_gpu
|
||||||
@require_fairscale
|
@require_fairscale
|
||||||
@unittest.skip("XXX: Fixme: hanging")
|
|
||||||
def test_run_seq2seq_fully_sharded_ddp(self):
|
def test_run_seq2seq_fully_sharded_ddp(self):
|
||||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)
|
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)
|
||||||
|
|
||||||
# test --sharded_ddp zero_dp_2 w/ --fp16
|
# test --sharded_ddp zero_dp_2 w/ --fp16
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_gpu
|
||||||
@require_fairscale
|
@require_fairscale
|
||||||
@unittest.skip("XXX: Fixme: hanging")
|
|
||||||
def test_run_seq2seq_fully_sharded_ddp_fp16(self):
|
def test_run_seq2seq_fully_sharded_ddp_fp16(self):
|
||||||
self.run_seq2seq_quick(
|
self.run_seq2seq_quick(
|
||||||
distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
|
distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
|
||||||
@@ -206,8 +204,8 @@ class TestTrainerExt(TestCasePlus):
|
|||||||
--warmup_steps 8
|
--warmup_steps 8
|
||||||
--evaluation_strategy steps
|
--evaluation_strategy steps
|
||||||
--logging_steps 0
|
--logging_steps 0
|
||||||
--save_steps {str(eval_steps)}
|
|
||||||
--eval_steps {str(eval_steps)}
|
--eval_steps {str(eval_steps)}
|
||||||
|
--save_steps {str(eval_steps)}
|
||||||
--group_by_length
|
--group_by_length
|
||||||
--label_smoothing_factor 0.1
|
--label_smoothing_factor 0.1
|
||||||
--adafactor
|
--adafactor
|
||||||
|
|||||||
@@ -1497,11 +1497,14 @@ class Trainer:
|
|||||||
"""
|
"""
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
self._save_tpu(output_dir)
|
self._save_tpu(output_dir)
|
||||||
else:
|
elif (
|
||||||
|
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
|
||||||
|
):
|
||||||
|
state_dict = self.model.state_dict()
|
||||||
if self.is_world_process_zero():
|
if self.is_world_process_zero():
|
||||||
|
self._save(output_dir, state_dict=state_dict)
|
||||||
|
elif self.is_world_process_zero():
|
||||||
self._save(output_dir)
|
self._save(output_dir)
|
||||||
if self.args.local_rank != -1:
|
|
||||||
dist.barrier()
|
|
||||||
|
|
||||||
def _save_tpu(self, output_dir: Optional[str] = None):
|
def _save_tpu(self, output_dir: Optional[str] = None):
|
||||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||||
@@ -1531,7 +1534,7 @@ class Trainer:
|
|||||||
if self.tokenizer is not None and self.is_world_process_zero():
|
if self.tokenizer is not None and self.is_world_process_zero():
|
||||||
self.tokenizer.save_pretrained(output_dir)
|
self.tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
def _save(self, output_dir: Optional[str] = None):
|
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
||||||
# If we are executing this function, we are the process zero, so we don't check for that.
|
# If we are executing this function, we are the process zero, so we don't check for that.
|
||||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
@@ -1540,13 +1543,16 @@ class Trainer:
|
|||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
if not isinstance(self.model, PreTrainedModel):
|
if not isinstance(self.model, PreTrainedModel):
|
||||||
if isinstance(unwrap_model(self.model), PreTrainedModel):
|
if isinstance(unwrap_model(self.model), PreTrainedModel):
|
||||||
unwrap_model(self.model).save_pretrained(output_dir, state_dict=self.model.state_dict())
|
if state_dict is None:
|
||||||
|
state_dict = self.model.state_dict()
|
||||||
|
unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
|
||||||
else:
|
else:
|
||||||
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
|
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
|
||||||
|
if state_dict is None:
|
||||||
state_dict = self.model.state_dict()
|
state_dict = self.model.state_dict()
|
||||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||||
else:
|
else:
|
||||||
self.model.save_pretrained(output_dir)
|
self.model.save_pretrained(output_dir, state_dict=state_dict)
|
||||||
if self.tokenizer is not None:
|
if self.tokenizer is not None:
|
||||||
self.tokenizer.save_pretrained(output_dir)
|
self.tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user