Small fixes to HP search (#7839)
This commit is contained in:
@@ -520,7 +520,7 @@ class Trainer:
|
||||
):
|
||||
if self.hp_search_backend is None or trial is None:
|
||||
return
|
||||
self.objective = self.compute_objective(metrics)
|
||||
self.objective = self.compute_objective(metrics.copy())
|
||||
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
||||
trial.report(self.objective, epoch)
|
||||
if trial.should_prune():
|
||||
|
||||
@@ -112,6 +112,7 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
|
||||
"""
|
||||
loss = metrics.pop("eval_loss", None)
|
||||
_ = metrics.pop("epoch", None)
|
||||
_ = metrics.pop("total_flos", None)
|
||||
return loss if len(metrics) == 0 else sum(metrics.values())
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user