[Trainer] add train loss and flops metrics reports (#11980)
* add train loss and flops metrics reports * consistency * add train_loss to skip keys * restore on_train_end call timing
This commit is contained in:
@@ -1362,20 +1362,24 @@ class Trainer:
|
|||||||
self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False
|
self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# add remaining tr_loss
|
||||||
|
self._total_loss_scalar += tr_loss.item()
|
||||||
|
train_loss = self._total_loss_scalar / self.state.global_step
|
||||||
|
|
||||||
metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
|
metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
|
||||||
self.store_flos()
|
self.store_flos()
|
||||||
metrics["total_flos"] = self.state.total_flos
|
metrics["total_flos"] = self.state.total_flos
|
||||||
self.log(metrics)
|
metrics["train_loss"] = train_loss
|
||||||
|
|
||||||
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
|
|
||||||
# add remaining tr_loss
|
|
||||||
self._total_loss_scalar += tr_loss.item()
|
|
||||||
|
|
||||||
self.is_in_train = False
|
self.is_in_train = False
|
||||||
|
|
||||||
self._memory_tracker.stop_and_update_metrics(metrics)
|
self._memory_tracker.stop_and_update_metrics(metrics)
|
||||||
|
|
||||||
return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics)
|
self.log(metrics)
|
||||||
|
|
||||||
|
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
|
||||||
|
|
||||||
|
return TrainOutput(self.state.global_step, train_loss, metrics)
|
||||||
|
|
||||||
def _load_state_dict_in_model(self, state_dict):
|
def _load_state_dict_in_model(self, state_dict):
|
||||||
load_result = self.model.load_state_dict(state_dict, strict=False)
|
load_result = self.model.load_state_dict(state_dict, strict=False)
|
||||||
|
|||||||
@@ -311,13 +311,11 @@ class TrainerIntegrationCommon:
|
|||||||
log_history = state.pop("log_history", None)
|
log_history = state.pop("log_history", None)
|
||||||
log_history1 = state1.pop("log_history", None)
|
log_history1 = state1.pop("log_history", None)
|
||||||
self.assertEqual(state, state1)
|
self.assertEqual(state, state1)
|
||||||
|
skip_log_keys = ["train_runtime", "train_samples_per_second", "train_steps_per_second", "train_loss"]
|
||||||
for log, log1 in zip(log_history, log_history1):
|
for log, log1 in zip(log_history, log_history1):
|
||||||
_ = log.pop("train_runtime", None)
|
for key in skip_log_keys:
|
||||||
_ = log1.pop("train_runtime", None)
|
_ = log.pop(key, None)
|
||||||
_ = log.pop("train_samples_per_second", None)
|
_ = log1.pop(key, None)
|
||||||
_ = log1.pop("train_samples_per_second", None)
|
|
||||||
_ = log.pop("train_steps_per_second", None)
|
|
||||||
_ = log1.pop("train_steps_per_second", None)
|
|
||||||
self.assertEqual(log, log1)
|
self.assertEqual(log, log1)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user