From 06fc3954a1ed4bf8351143cf561761d9e8fda65b Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 22 Oct 2020 14:26:55 -0400 Subject: [PATCH] Only log total_flos at the end of training (#7981) * Only log total_flos at the end of training * Fix test --- src/transformers/trainer.py | 7 ++++--- src/transformers/trainer_utils.py | 7 +------ tests/test_trainer_callback.py | 2 +- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 001ebdc018..21a452e2bd 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -830,6 +830,10 @@ class Trainer: state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)) self.model.load_state_dict(state_dict) + if self._total_flos is not None: + self.store_flos() + self.log({"total_flos": self.state.total_flos}) + self.control = self.callback_handler.on_train_end(self.args, self.state, self.control) return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step) @@ -1013,9 +1017,6 @@ class Trainer: return self._log(logs) if self.state.epoch is not None: logs["epoch"] = self.state.epoch - if self._total_flos is not None: - self.store_flos() - logs["total_flos"] = self.state.total_flos self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) output = {**logs, **{"step": self.state.global_step}} diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index fd35e6b55f..2a5deb6c57 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -114,12 +114,7 @@ def default_compute_objective(metrics: Dict[str, float]) -> float: metrics = copy.deepcopy(metrics) loss = metrics.pop("eval_loss", None) _ = metrics.pop("epoch", None) - _ = metrics.pop("total_flos", None) - if len(metrics) != 0: - raise RuntimeError( - "Metrics contains more entries than just 'eval_loss', 'epoch' and 'total_flos', please provide your own compute_objective function." - ) - return loss + return loss if len(metrics) == 0 else sum(metrics.values()) def default_hp_space_optuna(trial) -> Dict[str, float]: diff --git a/tests/test_trainer_callback.py b/tests/test_trainer_callback.py index ad3662adf2..133c4e29f2 100644 --- a/tests/test_trainer_callback.py +++ b/tests/test_trainer_callback.py @@ -125,7 +125,7 @@ class TrainerCallbackTest(unittest.TestCase): expected_events.append("on_epoch_end") if trainer.args.evaluation_strategy == EvaluationStrategy.EPOCH: expected_events += evaluation_events.copy() - expected_events.append("on_train_end") + expected_events += ["on_log", "on_train_end"] return expected_events def test_init_callback(self):