From 7dfdf793bb5e3a865f33ed597b10fc4526364af9 Mon Sep 17 00:00:00 2001 From: Teven Date: Thu, 24 Sep 2020 15:56:40 +0200 Subject: [PATCH] Fixing case in which `Trainer` hung while saving model in distributed training (#7365) * remote debugging * remote debugging * moved _store_flos call * moved _store_flos call * moved _store_flos call * removed debugging artefacts --- src/transformers/trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 6e97e13d04..5158e5cbbf 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -812,6 +812,7 @@ class Trainer: checkpoint_folder += f"-run-{run_id}" output_dir = os.path.join(self.args.output_dir, checkpoint_folder) + self.store_flos() self.save_model(output_dir) if self.is_world_process_zero(): @@ -1151,7 +1152,6 @@ class Trainer: raise ValueError("Trainer.model appears to not be a PreTrainedModel") xm.rendezvous("saving_checkpoint") - self._store_flos() self.model.save_pretrained(output_dir) if self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir) @@ -1164,7 +1164,6 @@ class Trainer: # They can then be reloaded using `from_pretrained()` if not isinstance(self.model, PreTrainedModel): raise ValueError("Trainer.model appears to not be a PreTrainedModel") - self._store_flos() self.model.save_pretrained(output_dir) if self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir) @@ -1175,7 +1174,7 @@ class Trainer: self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False ) - def _store_flos(self): + def store_flos(self): # Storing the number of floating-point operations that went into the model if self.total_flos is not None: if self.args.local_rank != -1: