TPU needs a rendezvous (#4339)

This commit is contained in:
Lysandre Debut
2020-05-14 08:59:52 -04:00
committed by GitHub
parent 94cb73c2d2
commit ef46ccb05c

View File

@@ -542,9 +542,28 @@ class Trainer:
Will only save from the master process.
"""
if self.is_world_master():
if is_tpu_available():
self._save_tpu(output_dir)
elif self.is_world_master():
self._save(output_dir)
def _save_tpu(self, output_dir: Optional[str] = None):
output_dir = output_dir if output_dir is not None else self.args.output_dir
logger.info("Saving model checkpoint to %s", output_dir)
if xm.is_master_ordinal():
os.makedirs(output_dir, exist_ok=True)
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
xm.rendezvous("saving_checkpoint")
self.model.save_pretrained(output_dir)
def _save(self, output_dir: Optional[str] = None):
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)