Add Trainer support for ReduceLROnPlateau (#23010)
* Add Trainer support for ReduceLROnPlateau (fixes #16503)
* Remove training argument and add default instance

Co-authored-by: mmeloux <maxime.meloux@loria.fr>
This commit is contained in:
@@ -22,7 +22,7 @@ from typing import Callable, Iterable, Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LambdaLR
|
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
|
||||||
|
|
||||||
from .trainer_utils import SchedulerType
|
from .trainer_utils import SchedulerType
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
@@ -49,6 +49,21 @@ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
|
|||||||
return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
|
return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
|
||||||
|
|
||||||
|
|
||||||
|
def get_reduce_on_plateau_schedule(optimizer: Optimizer, **kwargs):
    """
    Create a schedule with a constant learning rate that decreases when a metric has stopped improving.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        kwargs (`dict`, *optional*):
            Extra keyword arguments forwarded unchanged to `torch.optim.lr_scheduler.ReduceLROnPlateau` (for
            example `factor`, `patience` or `cooldown`). When omitted, the PyTorch defaults are used, which is
            the behaviour of calling this function with the optimizer alone.

    Return:
        `torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule.
    """

    # Forwarding **kwargs keeps the one-argument call working while letting callers tune the plateau
    # detection instead of being locked to the library defaults.
    return ReduceLROnPlateau(optimizer, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
|
def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
|
||||||
if current_step < num_warmup_steps:
|
if current_step < num_warmup_steps:
|
||||||
return float(current_step) / float(max(1.0, num_warmup_steps))
|
return float(current_step) / float(max(1.0, num_warmup_steps))
|
||||||
@@ -309,6 +324,7 @@ TYPE_TO_SCHEDULER_FUNCTION = {
|
|||||||
SchedulerType.CONSTANT: get_constant_schedule,
|
SchedulerType.CONSTANT: get_constant_schedule,
|
||||||
SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
|
SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
|
||||||
SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
|
SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
|
||||||
|
SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -335,7 +351,7 @@ def get_scheduler(
|
|||||||
"""
|
"""
|
||||||
name = SchedulerType(name)
|
name = SchedulerType(name)
|
||||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||||
if name == SchedulerType.CONSTANT:
|
if name == SchedulerType.CONSTANT or name == SchedulerType.REDUCE_ON_PLATEAU:
|
||||||
return schedule_func(optimizer)
|
return schedule_func(optimizer)
|
||||||
|
|
||||||
# All other schedulers require `num_warmup_steps`
|
# All other schedulers require `num_warmup_steps`
|
||||||
|
|||||||
@@ -1997,7 +1997,9 @@ class Trainer:
|
|||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
|
||||||
if optimizer_was_run and not self.deepspeed:
|
if optimizer_was_run and not self.deepspeed:
|
||||||
self.lr_scheduler.step()
|
# Delay optimizer scheduling until metrics are generated
|
||||||
|
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
|
||||||
|
self.lr_scheduler.step()
|
||||||
|
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
self.state.global_step += 1
|
self.state.global_step += 1
|
||||||
@@ -2288,6 +2290,10 @@ class Trainer:
|
|||||||
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
|
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
|
||||||
self._report_to_hp_search(trial, self.state.global_step, metrics)
|
self._report_to_hp_search(trial, self.state.global_step, metrics)
|
||||||
|
|
||||||
|
# Run delayed LR scheduler now that metrics are populated
|
||||||
|
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
|
||||||
|
self.lr_scheduler.step(metrics[self.args.metric_for_best_model])
|
||||||
|
|
||||||
if self.control.should_save:
|
if self.control.should_save:
|
||||||
self._save_checkpoint(model, trial, metrics=metrics)
|
self._save_checkpoint(model, trial, metrics=metrics)
|
||||||
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
||||||
|
|||||||
@@ -367,6 +367,7 @@ class SchedulerType(ExplicitEnum):
|
|||||||
CONSTANT = "constant"
|
CONSTANT = "constant"
|
||||||
CONSTANT_WITH_WARMUP = "constant_with_warmup"
|
CONSTANT_WITH_WARMUP = "constant_with_warmup"
|
||||||
INVERSE_SQRT = "inverse_sqrt"
|
INVERSE_SQRT = "inverse_sqrt"
|
||||||
|
REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"
|
||||||
|
|
||||||
|
|
||||||
class TrainerMemoryTracker:
|
class TrainerMemoryTracker:
|
||||||
|
|||||||
@@ -1194,7 +1194,9 @@ class TrainingArguments:
|
|||||||
f"https://github.com/huggingface/safetensors!"
|
f"https://github.com/huggingface/safetensors!"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.load_best_model_at_end and self.metric_for_best_model is None:
|
if (
|
||||||
|
self.load_best_model_at_end or self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU
|
||||||
|
) and self.metric_for_best_model is None:
|
||||||
self.metric_for_best_model = "loss"
|
self.metric_for_best_model = "loss"
|
||||||
if self.greater_is_better is None and self.metric_for_best_model is not None:
|
if self.greater_is_better is None and self.metric_for_best_model is not None:
|
||||||
self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"]
|
self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"]
|
||||||
@@ -1234,6 +1236,12 @@ class TrainingArguments:
|
|||||||
if not (self.sharded_ddp == "" or not self.sharded_ddp):
|
if not (self.sharded_ddp == "" or not self.sharded_ddp):
|
||||||
raise ValueError("sharded_ddp is not supported with bf16")
|
raise ValueError("sharded_ddp is not supported with bf16")
|
||||||
|
|
||||||
|
if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
|
||||||
|
if self.evaluation_strategy == IntervalStrategy.NO:
|
||||||
|
raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy")
|
||||||
|
if not is_torch_available():
|
||||||
|
raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires torch>=0.2.0")
|
||||||
|
|
||||||
self.optim = OptimizerNames(self.optim)
|
self.optim = OptimizerNames(self.optim)
|
||||||
if self.adafactor:
|
if self.adafactor:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
|||||||
@@ -575,6 +575,74 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertFalse(torch.allclose(trainer.model.b, b))
|
self.assertFalse(torch.allclose(trainer.model.b, b))
|
||||||
self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)
|
self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)
|
||||||
|
|
||||||
|
def test_reduce_lr_on_plateau_args(self):
    # A ReduceLROnPlateau instance supplied by the caller must be the scheduler the Trainer ends up
    # holding, with its custom hyper-parameters left untouched by training.
    train_dataset = RegressionDataset(length=64)
    eval_dataset = RegressionDataset(length=64)
    training_args = TrainingArguments(
        "./regression",
        evaluation_strategy="epoch",
        metric_for_best_model="eval_loss",
    )
    model = RegressionModel()
    sgd = torch.optim.SGD(model.parameters(), lr=1.0)
    plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(sgd, factor=0.2, patience=5, cooldown=2)
    trainer = Trainer(
        model,
        training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        optimizers=(sgd, plateau_scheduler),
    )
    trainer.train()

    scheduler = trainer.lr_scheduler
    self.assertIsInstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
    # Same three checks as before, in the same order: factor, patience, cooldown.
    for attribute, expected in (("factor", 0.2), ("patience", 5), ("cooldown", 2)):
        self.assertEqual(getattr(scheduler, attribute), expected)
|
||||||
|
|
||||||
|
def test_reduce_lr_on_plateau(self):
    # Train with lr_scheduler_type="reduce_lr_on_plateau" and check, epoch by epoch, that the logged
    # learning rate only drops after more than `patience` consecutive non-improving evaluations.

    class TrainerWithLRLogs(Trainer):
        def log(self, logs):
            # The scheduler only steps once metrics exist, so `_last_lr` is absent during the first
            # epoch; add the learning rate to the log entry only when it is available.
            if hasattr(self.lr_scheduler, "_last_lr"):
                logs["learning_rate"] = self.lr_scheduler._last_lr
            super().log(logs)

    train_dataset = RegressionDataset(length=64)
    eval_dataset = RegressionDataset(length=64)

    training_args = TrainingArguments(
        "./regression",
        lr_scheduler_type="reduce_lr_on_plateau",
        evaluation_strategy="epoch",
        metric_for_best_model="eval_loss",
        num_train_epochs=10,
        learning_rate=0.2,
    )
    model = RegressionModel()
    trainer = TrainerWithLRLogs(model, training_args, train_dataset=train_dataset, eval_dataset=eval_dataset)
    trainer.train()

    self.assertIsInstance(trainer.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
    patience = trainer.lr_scheduler.patience

    # Skip the first entry: it was logged before the scheduler had produced a learning rate.
    history = trainer.state.log_history[1:]
    lowest_loss = history[0]["eval_loss"]
    stale_epochs = 0
    # Walk consecutive (current, following) entries so each epoch's LR is compared with the next one's.
    for current, following in zip(history, history[1:]):
        loss = current["eval_loss"]
        expect_drop = False
        if loss > lowest_loss:
            stale_epochs += 1
            if stale_epochs > patience:
                self.assertLess(following["learning_rate"][0], current["learning_rate"][0])
                expect_drop = True
                stale_epochs = 0
        else:
            lowest_loss = loss
            stale_epochs = 0
        if not expect_drop:
            self.assertEqual(following["learning_rate"][0], current["learning_rate"][0])
|
||||||
|
|
||||||
def test_adafactor_lr_none(self):
|
def test_adafactor_lr_none(self):
|
||||||
# test the special case where lr=None, since Trainer can't not have lr_scheduler
|
# test the special case where lr=None, since Trainer can't not have lr_scheduler
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user