Add Trainer support for ReduceLROnPlateau (#23010)
* Add Trainer support for ReduceLROnPlateau Fixes #16503 * Remove training argument and add default instance --------- Co-authored-by: mmeloux <maxime.meloux@loria.fr>
This commit is contained in:
@@ -575,6 +575,74 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertFalse(torch.allclose(trainer.model.b, b))
|
||||
self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)
|
||||
|
||||
def test_reduce_lr_on_plateau_args(self):
|
||||
# test passed arguments for a custom ReduceLROnPlateau scheduler
|
||||
train_dataset = RegressionDataset(length=64)
|
||||
eval_dataset = RegressionDataset(length=64)
|
||||
args = TrainingArguments(
|
||||
"./regression",
|
||||
evaluation_strategy="epoch",
|
||||
metric_for_best_model="eval_loss",
|
||||
)
|
||||
model = RegressionModel()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
|
||||
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=5, cooldown=2)
|
||||
trainer = Trainer(
|
||||
model, args, train_dataset=train_dataset, eval_dataset=eval_dataset, optimizers=(optimizer, lr_scheduler)
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
self.assertIsInstance(trainer.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
|
||||
self.assertEqual(trainer.lr_scheduler.factor, 0.2)
|
||||
self.assertEqual(trainer.lr_scheduler.patience, 5)
|
||||
self.assertEqual(trainer.lr_scheduler.cooldown, 2)
|
||||
|
||||
def test_reduce_lr_on_plateau(self):
|
||||
# test the ReduceLROnPlateau scheduler
|
||||
|
||||
class TrainerWithLRLogs(Trainer):
|
||||
def log(self, logs):
|
||||
# the LR is computed after metrics and does not exist for the first epoch
|
||||
if hasattr(self.lr_scheduler, "_last_lr"):
|
||||
logs["learning_rate"] = self.lr_scheduler._last_lr
|
||||
super().log(logs)
|
||||
|
||||
train_dataset = RegressionDataset(length=64)
|
||||
eval_dataset = RegressionDataset(length=64)
|
||||
|
||||
args = TrainingArguments(
|
||||
"./regression",
|
||||
lr_scheduler_type="reduce_lr_on_plateau",
|
||||
evaluation_strategy="epoch",
|
||||
metric_for_best_model="eval_loss",
|
||||
num_train_epochs=10,
|
||||
learning_rate=0.2,
|
||||
)
|
||||
model = RegressionModel()
|
||||
trainer = TrainerWithLRLogs(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
|
||||
trainer.train()
|
||||
|
||||
self.assertIsInstance(trainer.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
|
||||
patience = trainer.lr_scheduler.patience
|
||||
|
||||
logs = trainer.state.log_history[1:]
|
||||
best_loss = logs[0]["eval_loss"]
|
||||
bad_epochs = 0
|
||||
for i, log in enumerate(logs[:-1]): # Compare learning rate to next epoch's
|
||||
loss = log["eval_loss"]
|
||||
just_decreased = False
|
||||
if loss > best_loss:
|
||||
bad_epochs += 1
|
||||
if bad_epochs > patience:
|
||||
self.assertLess(logs[i + 1]["learning_rate"][0], log["learning_rate"][0])
|
||||
just_decreased = True
|
||||
bad_epochs = 0
|
||||
else:
|
||||
best_loss = loss
|
||||
bad_epochs = 0
|
||||
if not just_decreased:
|
||||
self.assertEqual(logs[i + 1]["learning_rate"][0], log["learning_rate"][0])
|
||||
|
||||
def test_adafactor_lr_none(self):
|
||||
# test the special case where lr=None, since Trainer can't not have lr_scheduler
|
||||
|
||||
|
||||
Reference in New Issue
Block a user