Remove speed metrics from default compute objective (#10107)

This commit is contained in:
Shiva Zamani
2021-02-09 17:03:02 -07:00
committed by GitHub
parent 7c7962ba89
commit 85395e4901

View File

@@ -131,6 +131,10 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
metrics = copy.deepcopy(metrics) metrics = copy.deepcopy(metrics)
loss = metrics.pop("eval_loss", None) loss = metrics.pop("eval_loss", None)
_ = metrics.pop("epoch", None) _ = metrics.pop("epoch", None)
# Remove speed metrics
speed_metrics = [m for m in metrics.keys() if m.endswith("_runtime") or m.endswith("_samples_per_second")]
for sm in speed_metrics:
_ = metrics.pop(sm, None)
return loss if len(metrics) == 0 else sum(metrics.values()) return loss if len(metrics) == 0 else sum(metrics.values())