New option called "best" for args.save_strategy. (#31817)
* Add _determine_best_metric and new saving logic. 1. Logic to determine the best logic was separated out from `_save_checkpoint`. 2. In `_maybe_log_save_evaluate`, whether or not a new best metric was achieved is determined after each evaluation, and if the save strategy is "best' then the TrainerControl is updated accordingly. * Added SaveStrategy. Same as IntervalStrategy, but with a new attribute called BEST. * IntervalStrategy -> SaveStrategy * IntervalStratgy -> SaveStrategy for save_strat. * Interval -> Save in docstring. * Updated docstring for save_strategy. * Added SaveStrategy and made according changes. `save_strategy` previously followed `IntervalStrategy` but now follows `SaveStrategy`. Changes were made accordingly to the code and the docstring. * Changes from `make fixup`. * Removed redundant metrics argument. * Added new test_save_best_checkpoint test. 1. Checks for both cases where `metric_for_best_model` is explicitly provided and when it's not provided. 2. The first case should have two checkpoints saved, whereas the second should have three saved. * Changed should_training_end saving logic. The Trainer saves a checkpoints at the end of training by default as long as `save_strategy != SaveStrategy.NO`. This condition was modified to include `SaveStrategy.BEST` because it would be counterintuitive that we'd only want the best checkpoint to be saved but the last one is as well. * `args.metric_for_best_model` default to loss. * Undo metric_for_best_model update. * Remove checking metric_for_best_model. * Added test cases for loss and no metric. * Added error for metric and changed default best_metric. * Removed unused import. * `new_best_metric` -> `is_new_best_metric` Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Applied `is_new_best_metric` to all. Changes were made for consistency and also to fix a potential bug. --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Zach Mueller <muellerzr@gmail.com>
This commit is contained in:
committed by
GitHub
parent
8b3b9b48fc
commit
c1753436db
@@ -4041,6 +4041,89 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
reloaded_tokenizer(test_sentence, padding="max_length").input_ids,
|
||||
)
|
||||
|
||||
def test_save_best_checkpoint(self):
|
||||
freq = int(64 / self.batch_size)
|
||||
total = int(self.n_epochs * 64 / self.batch_size)
|
||||
|
||||
# Case 1: args.metric_for_best_model == "accuracy".
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(
|
||||
a=1.5,
|
||||
b=2.5,
|
||||
output_dir=tmpdir,
|
||||
learning_rate=0.1,
|
||||
eval_strategy="epoch",
|
||||
save_strategy="best",
|
||||
metric_for_best_model="accuracy",
|
||||
compute_metrics=AlmostAccuracy(),
|
||||
)
|
||||
self.assertTrue(trainer.args.metric_for_best_model == "accuracy")
|
||||
|
||||
with patch.object(
|
||||
trainer,
|
||||
"_evaluate",
|
||||
side_effect=[
|
||||
{"eval_loss": 0.03, "eval_accuracy": 0.60, "epoch": 1.0},
|
||||
{"eval_loss": 0.02, "eval_accuracy": 0.65, "epoch": 2.0},
|
||||
{"eval_loss": 0.01, "eval_accuracy": 0.64, "epoch": 3.0},
|
||||
],
|
||||
):
|
||||
trainer.train()
|
||||
|
||||
self.assertEqual(len(os.listdir(tmpdir)), 2)
|
||||
self.check_saved_checkpoints(
|
||||
output_dir=tmpdir,
|
||||
freq=freq,
|
||||
total=total,
|
||||
)
|
||||
|
||||
# Case 2: args.metric_for_best_model == "loss".
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(
|
||||
a=1.5,
|
||||
b=2.5,
|
||||
output_dir=tmpdir,
|
||||
learning_rate=0.1,
|
||||
eval_strategy="epoch",
|
||||
save_strategy="best",
|
||||
metric_for_best_model="loss",
|
||||
compute_metrics=AlmostAccuracy(),
|
||||
)
|
||||
self.assertTrue(trainer.args.metric_for_best_model == "loss")
|
||||
|
||||
with patch.object(
|
||||
trainer,
|
||||
"_evaluate",
|
||||
side_effect=[
|
||||
{"eval_loss": 0.03, "eval_accuracy": 0.60, "epoch": 1.0},
|
||||
{"eval_loss": 0.02, "eval_accuracy": 0.65, "epoch": 2.0},
|
||||
{"eval_loss": 0.03, "eval_accuracy": 0.66, "epoch": 3.0},
|
||||
],
|
||||
):
|
||||
trainer.train()
|
||||
|
||||
self.assertEqual(len(os.listdir(tmpdir)), 2)
|
||||
self.check_saved_checkpoints(
|
||||
output_dir=tmpdir,
|
||||
freq=freq,
|
||||
total=total,
|
||||
)
|
||||
|
||||
# Case 3: Metric name not provided; throw error.
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
with self.assertRaises(ValueError) as context:
|
||||
trainer = get_regression_trainer(
|
||||
a=1.5,
|
||||
b=2.5,
|
||||
output_dir=tmpdir,
|
||||
learning_rate=0.1,
|
||||
eval_strategy="epoch",
|
||||
save_strategy="best",
|
||||
compute_metrics=AlmostAccuracy(),
|
||||
)
|
||||
|
||||
self.assertIn("`args.metric_for_best_model` must be provided", str(context.exception))
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user