Only log total_flos at the end of training (#7981)
* Only log total_flos at the end of training * Fix test
This commit is contained in:
@@ -830,6 +830,10 @@ class Trainer:
|
||||
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
|
||||
self.model.load_state_dict(state_dict)
|
||||
|
||||
if self._total_flos is not None:
|
||||
self.store_flos()
|
||||
self.log({"total_flos": self.state.total_flos})
|
||||
|
||||
self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
|
||||
|
||||
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
|
||||
@@ -1013,9 +1017,6 @@ class Trainer:
|
||||
return self._log(logs)
|
||||
if self.state.epoch is not None:
|
||||
logs["epoch"] = self.state.epoch
|
||||
if self._total_flos is not None:
|
||||
self.store_flos()
|
||||
logs["total_flos"] = self.state.total_flos
|
||||
|
||||
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
|
||||
output = {**logs, **{"step": self.state.global_step}}
|
||||
|
||||
Reference in New Issue
Block a user