Only log total_flos at the end of training (#7981)

* Only log total_flos at the end of training

* Fix test
This commit is contained in:
Sylvain Gugger
2020-10-22 14:26:55 -04:00
committed by GitHub
parent ff65beafa3
commit 06fc3954a1
3 changed files with 6 additions and 10 deletions

View File

@@ -830,6 +830,10 @@ class Trainer:
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
self.model.load_state_dict(state_dict)
if self._total_flos is not None:
self.store_flos()
self.log({"total_flos": self.state.total_flos})
self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
@@ -1013,9 +1017,6 @@ class Trainer:
return self._log(logs)
if self.state.epoch is not None:
logs["epoch"] = self.state.epoch
if self._total_flos is not None:
self.store_flos()
logs["total_flos"] = self.state.total_flos
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
output = {**logs, **{"step": self.state.global_step}}