From ef46ccb05c601f413a774d43524591816406778d Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Thu, 14 May 2020 08:59:52 -0400 Subject: [PATCH] TPU needs a rendezvous (#4339) --- src/transformers/trainer.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 00300279c4..a3b630011d 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -542,9 +542,28 @@ class Trainer: Will only save from the master process. """ - if self.is_world_master(): + + if is_tpu_available(): + self._save_tpu(output_dir) + elif self.is_world_master(): self._save(output_dir) + def _save_tpu(self, output_dir: Optional[str] = None): + output_dir = output_dir if output_dir is not None else self.args.output_dir + logger.info("Saving model checkpoint to %s", output_dir) + + if xm.is_master_ordinal(): + os.makedirs(output_dir, exist_ok=True) + torch.save(self.args, os.path.join(output_dir, "training_args.bin")) + + # Save a trained model and configuration using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + if not isinstance(self.model, PreTrainedModel): + raise ValueError("Trainer.model appears to not be a PreTrainedModel") + + xm.rendezvous("saving_checkpoint") + self.model.save_pretrained(output_dir) + def _save(self, output_dir: Optional[str] = None): output_dir = output_dir if output_dir is not None else self.args.output_dir os.makedirs(output_dir, exist_ok=True)