Fairscale FSDP fix model save (#10596)
* Hotfix fairscale FSDP
* Evaluation works
* Save on process zero
This commit is contained in:
@@ -1497,11 +1497,14 @@ class Trainer:
|
||||
"""
|
||||
if is_torch_tpu_available():
|
||||
self._save_tpu(output_dir)
|
||||
else:
|
||||
elif (
|
||||
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
|
||||
):
|
||||
state_dict = self.model.state_dict()
|
||||
if self.is_world_process_zero():
|
||||
self._save(output_dir)
|
||||
if self.args.local_rank != -1:
|
||||
dist.barrier()
|
||||
self._save(output_dir, state_dict=state_dict)
|
||||
elif self.is_world_process_zero():
|
||||
self._save(output_dir)
|
||||
|
||||
def _save_tpu(self, output_dir: Optional[str] = None):
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
@@ -1531,7 +1534,7 @@ class Trainer:
|
||||
if self.tokenizer is not None and self.is_world_process_zero():
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
|
||||
def _save(self, output_dir: Optional[str] = None):
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
||||
# If we are executing this function, we are the process zero, so we don't check for that.
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
@@ -1540,13 +1543,16 @@ class Trainer:
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
if not isinstance(self.model, PreTrainedModel):
|
||||
if isinstance(unwrap_model(self.model), PreTrainedModel):
|
||||
unwrap_model(self.model).save_pretrained(output_dir, state_dict=self.model.state_dict())
|
||||
if state_dict is None:
|
||||
state_dict = self.model.state_dict()
|
||||
unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
|
||||
else:
|
||||
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
|
||||
state_dict = self.model.state_dict()
|
||||
if state_dict is None:
|
||||
state_dict = self.model.state_dict()
|
||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||
else:
|
||||
self.model.save_pretrained(output_dir)
|
||||
self.model.save_pretrained(output_dir, state_dict=state_dict)
|
||||
if self.tokenizer is not None:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user