dvclive callback: warn instead of fail when logging non-scalars (#27608)
* dvclive callback: warn instead of fail when logging non-scalars * tests: log lr as scalar
This commit is contained in:
@@ -1680,10 +1680,19 @@ class DVCLiveCallback(TrainerCallback):
|
|||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
self.setup(args, state, model)
|
self.setup(args, state, model)
|
||||||
if state.is_world_process_zero:
|
if state.is_world_process_zero:
|
||||||
|
from dvclive.plots import Metric
|
||||||
from dvclive.utils import standardize_metric_name
|
from dvclive.utils import standardize_metric_name
|
||||||
|
|
||||||
for key, value in logs.items():
|
for key, value in logs.items():
|
||||||
self.live.log_metric(standardize_metric_name(key, "dvclive.huggingface"), value)
|
if Metric.could_log(value):
|
||||||
|
self.live.log_metric(standardize_metric_name(key, "dvclive.huggingface"), value)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Trainer is attempting to log a value of "
|
||||||
|
f'"{value}" of type {type(value)} for key "{key}" as a scalar. '
|
||||||
|
"This invocation of DVCLive's Live.log_metric() "
|
||||||
|
"is incorrect so we dropped this attribute."
|
||||||
|
)
|
||||||
self.live.next_step()
|
self.live.next_step()
|
||||||
|
|
||||||
def on_save(self, args, state, control, **kwargs):
|
def on_save(self, args, state, control, **kwargs):
|
||||||
|
|||||||
@@ -672,7 +672,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
def log(self, logs):
|
def log(self, logs):
|
||||||
# the LR is computed after metrics and does not exist for the first epoch
|
# the LR is computed after metrics and does not exist for the first epoch
|
||||||
if hasattr(self.lr_scheduler, "_last_lr"):
|
if hasattr(self.lr_scheduler, "_last_lr"):
|
||||||
logs["learning_rate"] = self.lr_scheduler._last_lr
|
logs["learning_rate"] = self.lr_scheduler._last_lr[0]
|
||||||
super().log(logs)
|
super().log(logs)
|
||||||
|
|
||||||
train_dataset = RegressionDataset(length=64)
|
train_dataset = RegressionDataset(length=64)
|
||||||
@@ -702,14 +702,14 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
if loss > best_loss:
|
if loss > best_loss:
|
||||||
bad_epochs += 1
|
bad_epochs += 1
|
||||||
if bad_epochs > patience:
|
if bad_epochs > patience:
|
||||||
self.assertLess(logs[i + 1]["learning_rate"][0], log["learning_rate"][0])
|
self.assertLess(logs[i + 1]["learning_rate"], log["learning_rate"])
|
||||||
just_decreased = True
|
just_decreased = True
|
||||||
bad_epochs = 0
|
bad_epochs = 0
|
||||||
else:
|
else:
|
||||||
best_loss = loss
|
best_loss = loss
|
||||||
bad_epochs = 0
|
bad_epochs = 0
|
||||||
if not just_decreased:
|
if not just_decreased:
|
||||||
self.assertEqual(logs[i + 1]["learning_rate"][0], log["learning_rate"][0])
|
self.assertEqual(logs[i + 1]["learning_rate"], log["learning_rate"])
|
||||||
|
|
||||||
def test_adafactor_lr_none(self):
|
def test_adafactor_lr_none(self):
|
||||||
# test the special case where lr=None, since Trainer can't not have lr_scheduler
|
# test the special case where lr=None, since Trainer can't not have lr_scheduler
|
||||||
|
|||||||
Reference in New Issue
Block a user