Small fixes to HP search (#7839)
This commit is contained in:
@@ -520,7 +520,7 @@ class Trainer:
|
|||||||
):
|
):
|
||||||
if self.hp_search_backend is None or trial is None:
|
if self.hp_search_backend is None or trial is None:
|
||||||
return
|
return
|
||||||
self.objective = self.compute_objective(metrics)
|
self.objective = self.compute_objective(metrics.copy())
|
||||||
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
||||||
trial.report(self.objective, epoch)
|
trial.report(self.objective, epoch)
|
||||||
if trial.should_prune():
|
if trial.should_prune():
|
||||||
|
|||||||
@@ -112,6 +112,7 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
|
|||||||
"""
|
"""
|
||||||
loss = metrics.pop("eval_loss", None)
|
loss = metrics.pop("eval_loss", None)
|
||||||
_ = metrics.pop("epoch", None)
|
_ = metrics.pop("epoch", None)
|
||||||
|
_ = metrics.pop("total_flos", None)
|
||||||
return loss if len(metrics) == 0 else sum(metrics.values())
|
return loss if len(metrics) == 0 else sum(metrics.values())
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user