Add Trainer support for ReduceLROnPlateau (#23010)

* Add Trainer support for ReduceLROnPlateau

Fixes #16503

* Remove training argument and add default instance

---------

Co-authored-by: mmeloux <maxime.meloux@loria.fr>
This commit is contained in:
Maxime Méloux
2023-04-28 15:17:30 +02:00
committed by GitHub
parent cf7baf4060
commit 9b435204b1
5 changed files with 103 additions and 4 deletions

View File

@@ -1194,7 +1194,9 @@ class TrainingArguments:
f"https://github.com/huggingface/safetensors!"
)
if self.load_best_model_at_end and self.metric_for_best_model is None:
if (
self.load_best_model_at_end or self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU
) and self.metric_for_best_model is None:
self.metric_for_best_model = "loss"
if self.greater_is_better is None and self.metric_for_best_model is not None:
self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"]
@@ -1234,6 +1236,12 @@ class TrainingArguments:
if not (self.sharded_ddp == "" or not self.sharded_ddp):
raise ValueError("sharded_ddp is not supported with bf16")
if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
if self.evaluation_strategy == IntervalStrategy.NO:
raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy")
if not is_torch_available():
raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires torch>=0.2.0")
self.optim = OptimizerNames(self.optim)
if self.adafactor:
warnings.warn(