From 4ba203d9d3ab5f6ae8def490cbea44b61798fc54 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Tue, 1 Jun 2021 15:58:31 -0700
Subject: [PATCH] [Trainer] add train loss and flops metrics reports (#11980)

* add train loss and flops metrics reports

* consistency

* add train_loss to skip keys

* restore on_train_end call timing
---
 src/transformers/trainer.py | 16 ++++++++++------
 tests/test_trainer.py       | 10 ++++------
 2 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 0673174879..879a9c66d8 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1362,20 +1362,24 @@ class Trainer:
                     self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False
                 )
 
+        # add remaining tr_loss
+        self._total_loss_scalar += tr_loss.item()
+        train_loss = self._total_loss_scalar / self.state.global_step
+
         metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
         self.store_flos()
         metrics["total_flos"] = self.state.total_flos
-        self.log(metrics)
-
-        self.control = self.callback_handler.on_train_end(args, self.state, self.control)
-        # add remaining tr_loss
-        self._total_loss_scalar += tr_loss.item()
+        metrics["train_loss"] = train_loss
 
         self.is_in_train = False
 
         self._memory_tracker.stop_and_update_metrics(metrics)
 
-        return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics)
+        self.log(metrics)
+
+        self.control = self.callback_handler.on_train_end(args, self.state, self.control)
+
+        return TrainOutput(self.state.global_step, train_loss, metrics)
 
     def _load_state_dict_in_model(self, state_dict):
         load_result = self.model.load_state_dict(state_dict, strict=False)
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index abc31f1d46..89a68792c8 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -311,13 +311,11 @@ class TrainerIntegrationCommon:
         log_history = state.pop("log_history", None)
         log_history1 = state1.pop("log_history", None)
         self.assertEqual(state, state1)
+        skip_log_keys = ["train_runtime", "train_samples_per_second", "train_steps_per_second", "train_loss"]
         for log, log1 in zip(log_history, log_history1):
-            _ = log.pop("train_runtime", None)
-            _ = log1.pop("train_runtime", None)
-            _ = log.pop("train_samples_per_second", None)
-            _ = log1.pop("train_samples_per_second", None)
-            _ = log.pop("train_steps_per_second", None)
-            _ = log1.pop("train_steps_per_second", None)
+            for key in skip_log_keys:
+                _ = log.pop(key, None)
+                _ = log1.pop(key, None)
             self.assertEqual(log, log1)
 
 