Fix a case in which the Trainer hung while saving the model in distributed training (#7365)

* remote debugging

* remote debugging

* moved _store_flos call

* moved _store_flos call

* moved _store_flos call

* removed debugging artefacts
This commit is contained in:
Teven
2020-09-24 15:56:40 +02:00
committed by GitHub
parent 0ccb6f5c6d
commit 7dfdf793bb

View File

@@ -812,6 +812,7 @@ class Trainer:
checkpoint_folder += f"-run-{run_id}" checkpoint_folder += f"-run-{run_id}"
output_dir = os.path.join(self.args.output_dir, checkpoint_folder) output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
self.store_flos()
self.save_model(output_dir) self.save_model(output_dir)
if self.is_world_process_zero(): if self.is_world_process_zero():
@@ -1151,7 +1152,6 @@ class Trainer:
raise ValueError("Trainer.model appears to not be a PreTrainedModel") raise ValueError("Trainer.model appears to not be a PreTrainedModel")
xm.rendezvous("saving_checkpoint") xm.rendezvous("saving_checkpoint")
self._store_flos()
self.model.save_pretrained(output_dir) self.model.save_pretrained(output_dir)
if self.tokenizer is not None: if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir) self.tokenizer.save_pretrained(output_dir)
@@ -1164,7 +1164,6 @@ class Trainer:
# They can then be reloaded using `from_pretrained()` # They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel): if not isinstance(self.model, PreTrainedModel):
raise ValueError("Trainer.model appears to not be a PreTrainedModel") raise ValueError("Trainer.model appears to not be a PreTrainedModel")
self._store_flos()
self.model.save_pretrained(output_dir) self.model.save_pretrained(output_dir)
if self.tokenizer is not None: if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir) self.tokenizer.save_pretrained(output_dir)
@@ -1175,7 +1174,7 @@ class Trainer:
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
) )
def _store_flos(self): def store_flos(self):
# Storing the number of floating-point operations that went into the model # Storing the number of floating-point operations that went into the model
if self.total_flos is not None: if self.total_flos is not None:
if self.args.local_rank != -1: if self.args.local_rank != -1: