From eb186bc14ee0daae1fb80e5bbbdd12ae71ddfa36 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Fri, 16 Oct 2020 03:23:44 -0400 Subject: [PATCH] Small fixes to HP search (#7839) --- src/transformers/trainer.py | 2 +- src/transformers/trainer_utils.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 500a95a520..5d471f73b7 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -520,7 +520,7 @@ class Trainer: ): if self.hp_search_backend is None or trial is None: return - self.objective = self.compute_objective(metrics) + self.objective = self.compute_objective(metrics.copy()) if self.hp_search_backend == HPSearchBackend.OPTUNA: trial.report(self.objective, epoch) if trial.should_prune(): diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 96757a922e..ef3eaa1b05 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -112,6 +112,7 @@ def default_compute_objective(metrics: Dict[str, float]) -> float: """ loss = metrics.pop("eval_loss", None) _ = metrics.pop("epoch", None) + _ = metrics.pop("total_flos", None) return loss if len(metrics) == 0 else sum(metrics.values())