Add early stopping callback to pytorch trainer (#8581)
* Add early stopping patience and minimum threshold metric must improve to prevent early stopping to pytorch trainer * Add early stopping test * Set patience counter to 0 if best metric not defined yet * Make early stopping a callback. Add callback event for updating the best metric for early stopping callback to trigger on. * Run make style * make funciton name sensible * Improve new argument docstring wording and hope that flakey CI test passes. * Use on_evaluation callback instead of custom. Remove some debug printing * Move early stopping arguments and state into early stopping callback * Run make style * Remove old code * Fix docs formatting. make style went rogue on me. * Remove copied attributes and fix variable * Add assertions on training arguments instead of mutating them. Move comment out of public docs. * Make separate test for early stopping callback. Add test of invalid arguments. * Run make style... I remembered before CI this time! * appease flake8 * Add EarlyStoppingCallback to callback docs * Make docstring EarlyStoppingCallabck match other callbacks. * Fix typo in docs
This commit is contained in:
@@ -42,6 +42,7 @@ if is_torch_available():
|
||||
AutoModelForMaskedLM,
|
||||
AutoModelForSequenceClassification,
|
||||
DataCollatorForLanguageModeling,
|
||||
EarlyStoppingCallback,
|
||||
GlueDataset,
|
||||
GlueDataTrainingArguments,
|
||||
GPT2Config,
|
||||
@@ -765,6 +766,37 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
train_output = trainer.train()
|
||||
self.assertEqual(train_output.global_step, int(self.n_epochs))
|
||||
|
||||
def test_early_stopping_callback(self):
|
||||
# early stopping stops training before num_training_epochs
|
||||
trainer = get_regression_trainer(
|
||||
num_train_epochs=20,
|
||||
gradient_accumulation_steps=1,
|
||||
per_device_train_batch_size=16,
|
||||
load_best_model_at_end=True,
|
||||
evaluation_strategy=EvaluationStrategy.EPOCH,
|
||||
compute_metrics=AlmostAccuracy(),
|
||||
metric_for_best_model="accuracy",
|
||||
)
|
||||
trainer.add_callback(EarlyStoppingCallback(1, 0.0001))
|
||||
train_output = trainer.train()
|
||||
self.assertLess(train_output.global_step, 20 * 64 / 16)
|
||||
|
||||
# Invalid inputs to trainer with early stopping callback result in assertion error
|
||||
trainer = get_regression_trainer(
|
||||
num_train_epochs=20,
|
||||
gradient_accumulation_steps=1,
|
||||
per_device_train_batch_size=16,
|
||||
evaluation_strategy=EvaluationStrategy.EPOCH,
|
||||
compute_metrics=AlmostAccuracy(),
|
||||
metric_for_best_model="accuracy",
|
||||
)
|
||||
trainer.add_callback(EarlyStoppingCallback(1))
|
||||
self.assertEqual(trainer.state.global_step, 0)
|
||||
try:
|
||||
trainer.train()
|
||||
except AssertionError:
|
||||
self.assertEqual(trainer.state.global_step, 0)
|
||||
|
||||
def test_flos_extraction(self):
|
||||
trainer = get_regression_trainer(learning_rate=0.1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user