Fairscale FSDP fix model save (#10596)

* Hotfix fairscale FSDP

* Evaluation works

* Save on process zero
This commit is contained in:
Sylvain Gugger
2021-03-09 14:42:07 -05:00
committed by GitHub
parent ac17f71159
commit 0d909f6bd8
2 changed files with 19 additions and 15 deletions

View File

@@ -1497,11 +1497,14 @@ class Trainer:
"""
if is_torch_tpu_available():
self._save_tpu(output_dir)
else:
elif (
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
):
state_dict = self.model.state_dict()
if self.is_world_process_zero():
self._save(output_dir)
if self.args.local_rank != -1:
dist.barrier()
self._save(output_dir, state_dict=state_dict)
elif self.is_world_process_zero():
self._save(output_dir)
def _save_tpu(self, output_dir: Optional[str] = None):
output_dir = output_dir if output_dir is not None else self.args.output_dir
@@ -1531,7 +1534,7 @@ class Trainer:
if self.tokenizer is not None and self.is_world_process_zero():
self.tokenizer.save_pretrained(output_dir)
def _save(self, output_dir: Optional[str] = None):
def _save(self, output_dir: Optional[str] = None, state_dict=None):
# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
@@ -1540,13 +1543,16 @@ class Trainer:
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
if isinstance(unwrap_model(self.model), PreTrainedModel):
unwrap_model(self.model).save_pretrained(output_dir, state_dict=self.model.state_dict())
if state_dict is None:
state_dict = self.model.state_dict()
unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
if state_dict is None:
state_dict = self.model.state_dict()
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir)
self.model.save_pretrained(output_dir, state_dict=state_dict)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)