From 8ffc01a76ad4c446b16322c3b893a8a3f39c14c0 Mon Sep 17 00:00:00 2001 From: Colin Brochtrup Date: Mon, 23 Nov 2020 17:25:35 -0500 Subject: [PATCH] Add early stopping callback to pytorch trainer (#8581) * Add early stopping patience and minimum threshold metric must improve to prevent early stopping to pytorch trainer * Add early stopping test * Set patience counter to 0 if best metric not defined yet * Make early stopping a callback. Add callback event for updating the best metric for early stopping callback to trigger on. * Run make style * make function name sensible * Improve new argument docstring wording and hope that flakey CI test passes. * Use on_evaluation callback instead of custom. Remove some debug printing * Move early stopping arguments and state into early stopping callback * Run make style * Remove old code * Fix docs formatting. make style went rogue on me. * Remove copied attributes and fix variable * Add assertions on training arguments instead of mutating them. Move comment out of public docs. * Make separate test for early stopping callback. Add test of invalid arguments. * Run make style... I remembered before CI this time! * appease flake8 * Add EarlyStoppingCallback to callback docs * Make docstring EarlyStoppingCallback match other callbacks. * Fix typo in docs --- docs/source/main_classes/callback.rst | 2 + src/transformers/__init__.py | 1 + src/transformers/trainer_callback.py | 60 +++++++++++++++++++++++++++ tests/test_trainer.py | 32 ++++++++++++++ 4 files changed, 95 insertions(+) diff --git a/docs/source/main_classes/callback.rst b/docs/source/main_classes/callback.rst index f146244c1f..4f7d8d27fc 100644 --- a/docs/source/main_classes/callback.rst +++ b/docs/source/main_classes/callback.rst @@ -44,6 +44,8 @@ Here is the list of the available :class:`~transformers.TrainerCallback` in the .. autoclass:: transformers.ProgressCallback +.. autoclass:: transformers.EarlyStoppingCallback + .. 
autoclass:: transformers.integrations.TensorBoardCallback .. autoclass:: transformers.integrations.WandbCallback diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a92fb48812..1eede92505 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -253,6 +253,7 @@ else: # Trainer from .trainer_callback import ( DefaultFlowCallback, + EarlyStoppingCallback, PrinterCallback, ProgressCallback, TrainerCallback, diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 01be518da1..1ad546bc4f 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -21,6 +21,7 @@ import json from dataclasses import dataclass from typing import Dict, List, Optional, Union +import numpy as np from tqdm.auto import tqdm from .trainer_utils import EvaluationStrategy @@ -475,3 +476,62 @@ class PrinterCallback(TrainerCallback): _ = logs.pop("total_flos", None) if state.is_local_process_zero: print(logs) + + +class EarlyStoppingCallback(TrainerCallback): + """ + A :class:`~transformers.TrainerCallback` that handles early stopping. + + Args: + early_stopping_patience (:obj:`int`): + Use with :obj:`metric_for_best_model` to stop training when the specified metric worsens for + :obj:`early_stopping_patience` evaluation calls. + early_stopping_threshold (:obj:`float`, `optional`): + Use with TrainingArguments :obj:`metric_for_best_model` and :obj:`early_stopping_patience` to denote how + much the specified metric must improve to satisfy early stopping conditions. + + This callback depends on :class:`~transformers.TrainingArguments` argument `load_best_model_at_end` functionality + to set best_metric in :class:`~transformers.TrainerState`. 
+ """ + + def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0): + self.early_stopping_patience = early_stopping_patience + self.early_stopping_threshold = early_stopping_threshold + # early_stopping_patience_counter denotes the number of times validation metrics failed to improve. + self.early_stopping_patience_counter = 0 + + def check_metric_value(self, args, state, control, metric_value): + # best_metric is set by code for load_best_model + operator = np.greater if args.greater_is_better else np.less + if state.best_metric is None or ( + operator(metric_value, state.best_metric) + and abs(metric_value - state.best_metric) > self.early_stopping_threshold + ): + self.early_stopping_patience_counter = 0 + else: + self.early_stopping_patience_counter += 1 + + def on_train_begin(self, args, state, control, **kwargs): + assert args.load_best_model_at_end, "EarlyStoppingCallback requires load_best_model_at_end = True" + assert ( + args.metric_for_best_model is not None + ), "EarlyStoppingCallback requires metric_for_best_model is defined" + assert ( + args.evaluation_strategy != EvaluationStrategy.NO + ), "EarlyStoppingCallback requires EvaluationStrategy of steps or epoch" + + def on_evaluate(self, args, state, control, metrics, **kwargs): + metric_to_check = args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + metric_value = metrics.get(metric_to_check) + + if metric_value is None: + logger.warning( + f"early stopping required metric_for_best_model, but did not find {metric_to_check} so early stopping is disabled" + ) + return + + self.check_metric_value(args, state, control, metric_value) + if self.early_stopping_patience_counter >= self.early_stopping_patience: + control.should_training_stop = True diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 5d80654d48..3a5916d19a 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ 
-42,6 +42,7 @@ if is_torch_available(): AutoModelForMaskedLM, AutoModelForSequenceClassification, DataCollatorForLanguageModeling, + EarlyStoppingCallback, GlueDataset, GlueDataTrainingArguments, GPT2Config, @@ -765,6 +766,37 @@ class TrainerIntegrationTest(unittest.TestCase): train_output = trainer.train() self.assertEqual(train_output.global_step, int(self.n_epochs)) + def test_early_stopping_callback(self): + # early stopping stops training before num_training_epochs + trainer = get_regression_trainer( + num_train_epochs=20, + gradient_accumulation_steps=1, + per_device_train_batch_size=16, + load_best_model_at_end=True, + evaluation_strategy=EvaluationStrategy.EPOCH, + compute_metrics=AlmostAccuracy(), + metric_for_best_model="accuracy", + ) + trainer.add_callback(EarlyStoppingCallback(1, 0.0001)) + train_output = trainer.train() + self.assertLess(train_output.global_step, 20 * 64 / 16) + + # Invalid inputs to trainer with early stopping callback result in assertion error + trainer = get_regression_trainer( + num_train_epochs=20, + gradient_accumulation_steps=1, + per_device_train_batch_size=16, + evaluation_strategy=EvaluationStrategy.EPOCH, + compute_metrics=AlmostAccuracy(), + metric_for_best_model="accuracy", + ) + trainer.add_callback(EarlyStoppingCallback(1)) + self.assertEqual(trainer.state.global_step, 0) + try: + trainer.train() + except AssertionError: + self.assertEqual(trainer.state.global_step, 0) + def test_flos_extraction(self): trainer = get_regression_trainer(learning_rate=0.1)