dvclive callback: warn instead of fail when logging non-scalars (#27608)
* dvclive callback: warn instead of fail when logging non-scalars * tests: log lr as scalar
This commit is contained in:
@@ -672,7 +672,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
def log(self, logs):
|
||||
# the LR is computed after metrics and does not exist for the first epoch
|
||||
if hasattr(self.lr_scheduler, "_last_lr"):
|
||||
logs["learning_rate"] = self.lr_scheduler._last_lr
|
||||
logs["learning_rate"] = self.lr_scheduler._last_lr[0]
|
||||
super().log(logs)
|
||||
|
||||
train_dataset = RegressionDataset(length=64)
|
||||
@@ -702,14 +702,14 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
if loss > best_loss:
|
||||
bad_epochs += 1
|
||||
if bad_epochs > patience:
|
||||
self.assertLess(logs[i + 1]["learning_rate"][0], log["learning_rate"][0])
|
||||
self.assertLess(logs[i + 1]["learning_rate"], log["learning_rate"])
|
||||
just_decreased = True
|
||||
bad_epochs = 0
|
||||
else:
|
||||
best_loss = loss
|
||||
bad_epochs = 0
|
||||
if not just_decreased:
|
||||
self.assertEqual(logs[i + 1]["learning_rate"][0], log["learning_rate"][0])
|
||||
self.assertEqual(logs[i + 1]["learning_rate"], log["learning_rate"])
|
||||
|
||||
def test_adafactor_lr_none(self):
|
||||
# test the special case where lr=None, since Trainer can't not have lr_scheduler
|
||||
|
||||
Reference in New Issue
Block a user