Add early stopping callback to pytorch trainer (#8581)
* Add early stopping patience and minimum threshold metric must improve to prevent early stopping to pytorch trainer * Add early stopping test * Set patience counter to 0 if best metric not defined yet * Make early stopping a callback. Add callback event for updating the best metric for early stopping callback to trigger on. * Run make style * make function name sensible * Improve new argument docstring wording and hope that flaky CI test passes. * Use on_evaluation callback instead of custom. Remove some debug printing * Move early stopping arguments and state into early stopping callback * Run make style * Remove old code * Fix docs formatting. make style went rogue on me. * Remove copied attributes and fix variable * Add assertions on training arguments instead of mutating them. Move comment out of public docs. * Make separate test for early stopping callback. Add test of invalid arguments. * Run make style... I remembered before CI this time! * appease flake8 * Add EarlyStoppingCallback to callback docs * Make docstring EarlyStoppingCallback match other callbacks. * Fix typo in docs
This commit is contained in:
@@ -44,6 +44,8 @@ Here is the list of the available :class:`~transformers.TrainerCallback` in the
|
|||||||
|
|
||||||
.. autoclass:: transformers.ProgressCallback
|
.. autoclass:: transformers.ProgressCallback
|
||||||
|
|
||||||
|
.. autoclass:: transformers.EarlyStoppingCallback
|
||||||
|
|
||||||
.. autoclass:: transformers.integrations.TensorBoardCallback
|
.. autoclass:: transformers.integrations.TensorBoardCallback
|
||||||
|
|
||||||
.. autoclass:: transformers.integrations.WandbCallback
|
.. autoclass:: transformers.integrations.WandbCallback
|
||||||
|
|||||||
@@ -253,6 +253,7 @@ else:
|
|||||||
# Trainer
|
# Trainer
|
||||||
from .trainer_callback import (
|
from .trainer_callback import (
|
||||||
DefaultFlowCallback,
|
DefaultFlowCallback,
|
||||||
|
EarlyStoppingCallback,
|
||||||
PrinterCallback,
|
PrinterCallback,
|
||||||
ProgressCallback,
|
ProgressCallback,
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import json
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
from .trainer_utils import EvaluationStrategy
|
from .trainer_utils import EvaluationStrategy
|
||||||
@@ -475,3 +476,62 @@ class PrinterCallback(TrainerCallback):
|
|||||||
_ = logs.pop("total_flos", None)
|
_ = logs.pop("total_flos", None)
|
||||||
if state.is_local_process_zero:
|
if state.is_local_process_zero:
|
||||||
print(logs)
|
print(logs)
|
||||||
|
|
||||||
|
|
||||||
|
class EarlyStoppingCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that handles early stopping.

    Args:
        early_stopping_patience (:obj:`int`):
            Use with :obj:`metric_for_best_model` to stop training when the specified metric fails to improve for
            :obj:`early_stopping_patience` evaluation calls.
        early_stopping_threshold (:obj:`float`, `optional`):
            Use with TrainingArguments :obj:`metric_for_best_model` and :obj:`early_stopping_patience` to denote how
            much the specified metric must improve to satisfy early stopping conditions. :obj:`None` is treated as
            :obj:`0.0`.

    This callback depends on :class:`~transformers.TrainingArguments` argument `load_best_model_at_end` functionality
    to set best_metric in :class:`~transformers.TrainerState`.
    """

    def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_threshold = early_stopping_threshold
        # early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
        self.early_stopping_patience_counter = 0

    def check_metric_value(self, args, state, control, metric_value):
        """
        Update the patience counter from the latest value of the monitored metric.

        The counter is reset to 0 when ``state.best_metric`` is not set yet, or when ``metric_value`` beats
        ``state.best_metric`` (in the direction given by ``args.greater_is_better``) by more than the threshold.
        Otherwise the counter is incremented by one.
        """
        # best_metric is set by code for load_best_model
        operator = np.greater if args.greater_is_better else np.less
        # The constructor accepts None for the threshold; comparing a float against None would raise
        # TypeError, so None is interpreted as "no minimum improvement required".
        threshold = 0.0 if self.early_stopping_threshold is None else self.early_stopping_threshold
        if state.best_metric is None or (
            operator(metric_value, state.best_metric) and abs(metric_value - state.best_metric) > threshold
        ):
            self.early_stopping_patience_counter = 0
        else:
            self.early_stopping_patience_counter += 1

    def on_train_begin(self, args, state, control, **kwargs):
        """
        Validate that the training arguments provide what early stopping needs.

        Raises:
            AssertionError: if ``load_best_model_at_end`` is falsy, ``metric_for_best_model`` is ``None``, or
            ``evaluation_strategy`` is ``EvaluationStrategy.NO``.
        """
        # Assertions are kept (rather than another exception type) so misconfiguration keeps raising
        # AssertionError exactly as before.
        assert args.load_best_model_at_end, "EarlyStoppingCallback requires load_best_model_at_end = True"
        assert (
            args.metric_for_best_model is not None
        ), "EarlyStoppingCallback requires metric_for_best_model is defined"
        assert (
            args.evaluation_strategy != EvaluationStrategy.NO
        ), "EarlyStoppingCallback requires EvaluationStrategy of steps or epoch"

    def on_evaluate(self, args, state, control, metrics, **kwargs):
        """
        Check the monitored metric after an evaluation and request a stop once patience is exhausted.

        The metric name is ``args.metric_for_best_model``, prefixed with ``eval_`` when it does not already start
        with it. If that key is missing from ``metrics``, a warning is logged and nothing else happens. Otherwise
        the patience counter is updated and ``control.should_training_stop`` is set to ``True`` when the counter
        reaches ``early_stopping_patience``.
        """
        metric_to_check = args.metric_for_best_model
        if not metric_to_check.startswith("eval_"):
            metric_to_check = f"eval_{metric_to_check}"
        metric_value = metrics.get(metric_to_check)

        if metric_value is None:
            logger.warning(
                f"early stopping required metric_for_best_model, but did not find {metric_to_check} so early stopping is disabled"
            )
            return

        self.check_metric_value(args, state, control, metric_value)
        if self.early_stopping_patience_counter >= self.early_stopping_patience:
            control.should_training_stop = True
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ if is_torch_available():
|
|||||||
AutoModelForMaskedLM,
|
AutoModelForMaskedLM,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
DataCollatorForLanguageModeling,
|
DataCollatorForLanguageModeling,
|
||||||
|
EarlyStoppingCallback,
|
||||||
GlueDataset,
|
GlueDataset,
|
||||||
GlueDataTrainingArguments,
|
GlueDataTrainingArguments,
|
||||||
GPT2Config,
|
GPT2Config,
|
||||||
@@ -765,6 +766,37 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
train_output = trainer.train()
|
train_output = trainer.train()
|
||||||
self.assertEqual(train_output.global_step, int(self.n_epochs))
|
self.assertEqual(train_output.global_step, int(self.n_epochs))
|
||||||
|
|
||||||
|
def test_early_stopping_callback(self):
    """Check that EarlyStoppingCallback stops training early and rejects invalid training arguments."""
    # early stopping stops training before num_training_epochs
    trainer = get_regression_trainer(
        num_train_epochs=20,
        gradient_accumulation_steps=1,
        per_device_train_batch_size=16,
        load_best_model_at_end=True,
        evaluation_strategy=EvaluationStrategy.EPOCH,
        compute_metrics=AlmostAccuracy(),
        metric_for_best_model="accuracy",
    )
    trainer.add_callback(EarlyStoppingCallback(1, 0.0001))
    train_output = trainer.train()
    # 20 epochs * (64 / 16) steps per epoch is the step count of a full run
    # (presumably 64 is the regression dataset size -- TODO confirm against get_regression_trainer).
    self.assertLess(train_output.global_step, 20 * 64 / 16)

    # Invalid inputs to trainer with early stopping callback result in assertion error
    # (load_best_model_at_end is deliberately left out here).
    trainer = get_regression_trainer(
        num_train_epochs=20,
        gradient_accumulation_steps=1,
        per_device_train_batch_size=16,
        evaluation_strategy=EvaluationStrategy.EPOCH,
        compute_metrics=AlmostAccuracy(),
        metric_for_best_model="accuracy",
    )
    trainer.add_callback(EarlyStoppingCallback(1))
    self.assertEqual(trainer.state.global_step, 0)
    # assertRaises makes the test fail if train() does NOT raise; a bare try/except would pass silently.
    with self.assertRaises(AssertionError):
        trainer.train()
    # No training step may have run before the arguments were rejected.
    self.assertEqual(trainer.state.global_step, 0)
|
||||||
|
|
||||||
def test_flos_extraction(self):
|
def test_flos_extraction(self):
|
||||||
trainer = get_regression_trainer(learning_rate=0.1)
|
trainer = get_regression_trainer(learning_rate=0.1)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user