From 9b435204b162fe89555e654cc524c4275cda521a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maxime=20M=C3=A9loux?= Date: Fri, 28 Apr 2023 15:17:30 +0200 Subject: [PATCH] Add Trainer support for ReduceLROnPlateau (#23010) * Add Trainer support for ReduceLROnPlateau Fixes #16503 * Remove training argument and add default instance --------- Co-authored-by: mmeloux --- src/transformers/optimization.py | 20 ++++++++- src/transformers/trainer.py | 8 +++- src/transformers/trainer_utils.py | 1 + src/transformers/training_args.py | 10 ++++- tests/trainer/test_trainer.py | 68 +++++++++++++++++++++++++++++++ 5 files changed, 103 insertions(+), 4 deletions(-) diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py index d3dd43bff2..8c9430fb6f 100644 --- a/src/transformers/optimization.py +++ b/src/transformers/optimization.py @@ -22,7 +22,7 @@ from typing import Callable, Iterable, Optional, Tuple, Union import torch from torch import nn from torch.optim import Optimizer -from torch.optim.lr_scheduler import LambdaLR +from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau from .trainer_utils import SchedulerType from .utils import logging @@ -49,6 +49,21 @@ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) +def get_reduce_on_plateau_schedule(optimizer: Optimizer): + """ + Create a schedule with a constant learning rate that decreases when a metric has stopped improving. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + + Return: + `torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule. + """ + + return ReduceLROnPlateau(optimizer) + + def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int): if current_step < num_warmup_steps: return float(current_step) / float(max(1.0, num_warmup_steps)) @@ -309,6 +324,7 @@ TYPE_TO_SCHEDULER_FUNCTION = { SchedulerType.CONSTANT: get_constant_schedule, SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup, SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule, + SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule, } @@ -335,7 +351,7 @@ def get_scheduler( """ name = SchedulerType(name) schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] - if name == SchedulerType.CONSTANT: + if name == SchedulerType.CONSTANT or name == SchedulerType.REDUCE_ON_PLATEAU: return schedule_func(optimizer) # All other schedulers require `num_warmup_steps` diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 147cdc8d9a..6019dc2099 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1997,7 +1997,9 @@ class Trainer: self.optimizer.step() if optimizer_was_run and not self.deepspeed: - self.lr_scheduler.step() + # Delay optimizer scheduling until metrics are generated + if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step() model.zero_grad() self.state.global_step += 1 @@ -2288,6 +2290,10 @@ class Trainer: metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) self._report_to_hp_search(trial, self.state.global_step, metrics) + # Run delayed LR scheduler now that metrics are populated + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step(metrics[self.args.metric_for_best_model]) + if self.control.should_save: self._save_checkpoint(model, trial, metrics=metrics) self.control = self.callback_handler.on_save(self.args, self.state, self.control) diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index a213e4b1f4..bb44c4c1ab 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -367,6 +367,7 @@ class SchedulerType(ExplicitEnum): CONSTANT = "constant" CONSTANT_WITH_WARMUP = "constant_with_warmup" INVERSE_SQRT = "inverse_sqrt" + REDUCE_ON_PLATEAU = "reduce_lr_on_plateau" class TrainerMemoryTracker: diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 55945a4382..80df5d2100 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1194,7 +1194,9 @@ class TrainingArguments: f"https://github.com/huggingface/safetensors!" ) - if self.load_best_model_at_end and self.metric_for_best_model is None: + if ( + self.load_best_model_at_end or self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU + ) and self.metric_for_best_model is None: self.metric_for_best_model = "loss" if self.greater_is_better is None and self.metric_for_best_model is not None: self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"] @@ -1234,6 +1236,12 @@ class TrainingArguments: if not (self.sharded_ddp == "" or not self.sharded_ddp): raise ValueError("sharded_ddp is not supported with bf16") + if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU: + if self.evaluation_strategy == IntervalStrategy.NO: + raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy") + if not is_torch_available(): + raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires torch>=0.2.0") + self.optim = OptimizerNames(self.optim) if self.adafactor: warnings.warn( diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 78b6afeacd..63a1263588 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -575,6 +575,74 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon): self.assertFalse(torch.allclose(trainer.model.b, b)) self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0) + def test_reduce_lr_on_plateau_args(self): + # test passed arguments for a custom ReduceLROnPlateau scheduler + train_dataset = RegressionDataset(length=64) + eval_dataset = RegressionDataset(length=64) + args = TrainingArguments( + "./regression", + evaluation_strategy="epoch", + metric_for_best_model="eval_loss", + ) + model = RegressionModel() + optimizer = torch.optim.SGD(model.parameters(), lr=1.0) + lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=5, cooldown=2) + trainer = Trainer( + model, args, train_dataset=train_dataset, eval_dataset=eval_dataset, optimizers=(optimizer, lr_scheduler) + ) + trainer.train() + + self.assertIsInstance(trainer.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) + self.assertEqual(trainer.lr_scheduler.factor, 0.2) + self.assertEqual(trainer.lr_scheduler.patience, 5) + self.assertEqual(trainer.lr_scheduler.cooldown, 2) + + def test_reduce_lr_on_plateau(self): + # test the ReduceLROnPlateau scheduler + + class TrainerWithLRLogs(Trainer): + def log(self, logs): + # the LR is computed after metrics and does not exist for the first epoch + if hasattr(self.lr_scheduler, "_last_lr"): + logs["learning_rate"] = self.lr_scheduler._last_lr + super().log(logs) + + train_dataset = RegressionDataset(length=64) + eval_dataset = RegressionDataset(length=64) + + args = TrainingArguments( + "./regression", + lr_scheduler_type="reduce_lr_on_plateau", + evaluation_strategy="epoch", + metric_for_best_model="eval_loss", + num_train_epochs=10, + learning_rate=0.2, + ) + model = RegressionModel() + trainer = TrainerWithLRLogs(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset) + trainer.train() + + self.assertIsInstance(trainer.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) + patience = trainer.lr_scheduler.patience + + logs = trainer.state.log_history[1:] + best_loss = logs[0]["eval_loss"] + bad_epochs = 0 + for i, log in enumerate(logs[:-1]): # Compare learning rate to next epoch's + loss = log["eval_loss"] + just_decreased = False + if loss > best_loss: + bad_epochs += 1 + if bad_epochs > patience: + self.assertLess(logs[i + 1]["learning_rate"][0], log["learning_rate"][0]) + just_decreased = True + bad_epochs = 0 + else: + best_loss = loss + bad_epochs = 0 + if not just_decreased: + self.assertEqual(logs[i + 1]["learning_rate"][0], log["learning_rate"][0]) + def test_adafactor_lr_none(self): # test the special case where lr=None, since Trainer can't not have lr_scheduler