Add early stopping callback to pytorch trainer (#8581)

* Add early stopping patience and the minimum threshold by which the metric must improve (to avoid early stopping) to the PyTorch trainer

* Add early stopping test

* Set patience counter to 0 if best metric not defined yet

* Make early stopping a callback. Add callback event for updating the best metric for early stopping callback to trigger on.

* Run make style

* Make function name sensible

* Improve new argument docstring wording and hope that flaky CI test passes.

* Use on_evaluation callback instead of custom. Remove some debug printing

* Move early stopping arguments and state into early stopping callback

* Run make style

* Remove old code

* Fix docs formatting. make style went rogue on me.

* Remove copied attributes and fix variable

* Add assertions on training arguments instead of mutating them. Move comment out of public docs.

* Make separate test for early stopping callback. Add test of invalid arguments.

* Run make style... I remembered before CI this time!

* appease flake8

* Add EarlyStoppingCallback to callback docs

* Make EarlyStoppingCallback docstring match other callbacks.

* Fix typo in docs
This commit is contained in:
Colin Brochtrup
2020-11-23 17:25:35 -05:00
committed by GitHub
parent 367f497dec
commit 8ffc01a76a
4 changed files with 95 additions and 0 deletions

View File

@@ -44,6 +44,8 @@ Here is the list of the available :class:`~transformers.TrainerCallback` in the
.. autoclass:: transformers.ProgressCallback .. autoclass:: transformers.ProgressCallback
.. autoclass:: transformers.EarlyStoppingCallback
.. autoclass:: transformers.integrations.TensorBoardCallback .. autoclass:: transformers.integrations.TensorBoardCallback
.. autoclass:: transformers.integrations.WandbCallback .. autoclass:: transformers.integrations.WandbCallback

View File

@@ -253,6 +253,7 @@ else:
# Trainer # Trainer
from .trainer_callback import ( from .trainer_callback import (
DefaultFlowCallback, DefaultFlowCallback,
EarlyStoppingCallback,
PrinterCallback, PrinterCallback,
ProgressCallback, ProgressCallback,
TrainerCallback, TrainerCallback,

View File

@@ -21,6 +21,7 @@ import json
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import numpy as np
from tqdm.auto import tqdm from tqdm.auto import tqdm
from .trainer_utils import EvaluationStrategy from .trainer_utils import EvaluationStrategy
@@ -475,3 +476,62 @@ class PrinterCallback(TrainerCallback):
_ = logs.pop("total_flos", None) _ = logs.pop("total_flos", None)
if state.is_local_process_zero: if state.is_local_process_zero:
print(logs) print(logs)
class EarlyStoppingCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that handles early stopping.

    Args:
        early_stopping_patience (:obj:`int`):
            Use with :obj:`metric_for_best_model` to stop training when the specified metric fails to improve for
            :obj:`early_stopping_patience` consecutive evaluation calls.
        early_stopping_threshold(:obj:`float`, `optional`):
            Use with TrainingArguments :obj:`metric_for_best_model` and :obj:`early_stopping_patience` to denote how
            much the specified metric must improve to satisfy early stopping conditions. Defaults to 0.0;
            :obj:`None` is treated as 0.0.

    This callback depends on :class:`~transformers.TrainingArguments` argument `load_best_model_at_end` functionality
    to set best_metric in :class:`~transformers.TrainerState`.
    """

    def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_threshold = early_stopping_threshold
        # early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
        self.early_stopping_patience_counter = 0

    def check_metric_value(self, args, state, control, metric_value):
        """
        Update the patience counter from the latest value of the monitored metric.

        The counter is reset to 0 when no best metric has been recorded yet, or when ``metric_value`` beats
        ``state.best_metric`` (in the direction given by ``args.greater_is_better``) by strictly more than the
        threshold. Otherwise the counter is incremented by one.
        """
        # best_metric is set by code for load_best_model
        is_better = np.greater if args.greater_is_better else np.less
        # The threshold is typed Optional: treat None as "no minimum improvement" rather than letting the
        # comparison below raise a TypeError.
        threshold = 0.0 if self.early_stopping_threshold is None else self.early_stopping_threshold
        if state.best_metric is None or (
            is_better(metric_value, state.best_metric) and abs(metric_value - state.best_metric) > threshold
        ):
            self.early_stopping_patience_counter = 0
        else:
            self.early_stopping_patience_counter += 1

    def on_train_begin(self, args, state, control, **kwargs):
        """
        Check that the training arguments are compatible with early stopping.

        Raises:
            AssertionError: if ``load_best_model_at_end`` is falsy, ``metric_for_best_model`` is ``None``, or
            ``evaluation_strategy`` is ``EvaluationStrategy.NO``.
        """
        # NOTE: plain asserts are kept on purpose so callers that catch AssertionError keep working; be aware that
        # they are stripped when Python runs with -O.
        assert args.load_best_model_at_end, "EarlyStoppingCallback requires load_best_model_at_end = True"
        assert (
            args.metric_for_best_model is not None
        ), "EarlyStoppingCallback requires metric_for_best_model is defined"
        assert (
            args.evaluation_strategy != EvaluationStrategy.NO
        ), "EarlyStoppingCallback requires EvaluationStrategy of steps or epoch"

    def on_evaluate(self, args, state, control, metrics, **kwargs):
        """
        Look up the monitored metric in ``metrics`` and request a stop once patience is exhausted.

        The metric name is ``args.metric_for_best_model``, prefixed with ``eval_`` when it does not already start
        with it. If the metric is missing from ``metrics`` a warning is logged and nothing else happens.
        """
        metric_to_check = args.metric_for_best_model
        if not metric_to_check.startswith("eval_"):
            metric_to_check = f"eval_{metric_to_check}"
        metric_value = metrics.get(metric_to_check)

        if metric_value is None:
            logger.warning(
                f"early stopping required metric_for_best_model, but did not find {metric_to_check} so early stopping is disabled"
            )
            return

        self.check_metric_value(args, state, control, metric_value)
        if self.early_stopping_patience_counter >= self.early_stopping_patience:
            control.should_training_stop = True

View File

@@ -42,6 +42,7 @@ if is_torch_available():
AutoModelForMaskedLM, AutoModelForMaskedLM,
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
DataCollatorForLanguageModeling, DataCollatorForLanguageModeling,
EarlyStoppingCallback,
GlueDataset, GlueDataset,
GlueDataTrainingArguments, GlueDataTrainingArguments,
GPT2Config, GPT2Config,
@@ -765,6 +766,37 @@ class TrainerIntegrationTest(unittest.TestCase):
train_output = trainer.train() train_output = trainer.train()
self.assertEqual(train_output.global_step, int(self.n_epochs)) self.assertEqual(train_output.global_step, int(self.n_epochs))
def test_early_stopping_callback(self):
    """Early stopping ends training early, and incompatible arguments are rejected before any step runs."""
    # early stopping stops training before num_training_epochs
    trainer = get_regression_trainer(
        num_train_epochs=20,
        gradient_accumulation_steps=1,
        per_device_train_batch_size=16,
        load_best_model_at_end=True,
        evaluation_strategy=EvaluationStrategy.EPOCH,
        compute_metrics=AlmostAccuracy(),
        metric_for_best_model="accuracy",
    )
    trainer.add_callback(EarlyStoppingCallback(1, 0.0001))
    train_output = trainer.train()
    # Upper bound is the step count of a full run: 20 epochs * 64 / batch size 16
    # (presumably 64 is the regression dataset length -- confirm against get_regression_trainer).
    self.assertLess(train_output.global_step, 20 * 64 / 16)

    # Invalid inputs to trainer with early stopping callback result in assertion error
    # (load_best_model_at_end is deliberately left unset here).
    trainer = get_regression_trainer(
        num_train_epochs=20,
        gradient_accumulation_steps=1,
        per_device_train_batch_size=16,
        evaluation_strategy=EvaluationStrategy.EPOCH,
        compute_metrics=AlmostAccuracy(),
        metric_for_best_model="accuracy",
    )
    trainer.add_callback(EarlyStoppingCallback(1))
    self.assertEqual(trainer.state.global_step, 0)
    # assertRaises fails the test if no AssertionError is raised; a bare try/except would pass silently.
    with self.assertRaises(AssertionError):
        trainer.train()
    self.assertEqual(trainer.state.global_step, 0)
def test_flos_extraction(self): def test_flos_extraction(self):
trainer = get_regression_trainer(learning_rate=0.1) trainer = get_regression_trainer(learning_rate=0.1)