New option called "best" for args.save_strategy. (#31817)

* Add _determine_best_metric and new saving logic.

1. Logic to determine the best metric was separated out from
`_save_checkpoint`.
2. In `_maybe_log_save_evaluate`, whether or not a new best metric was
achieved is determined after each evaluation, and if the save strategy
is "best" then the TrainerControl is updated accordingly.
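
A minimal sketch of the comparison described in the two points above, not the actual Trainer code. The helper names `is_new_best` and `epochs_to_save` are hypothetical, and the `eval_` key prefix is taken from the metric dictionaries used in the test in this commit.

```python
import operator
from typing import Dict, List, Optional, Tuple


def is_new_best(
    metrics: Dict[str, float],
    metric_name: str,
    best_so_far: Optional[float],
    greater_is_better: bool,
) -> Tuple[bool, float]:
    """Return (is_new_best, updated_best) for a single evaluation result."""
    # Assumption: evaluation metrics are stored under an "eval_" prefix.
    key = metric_name if metric_name.startswith("eval_") else f"eval_{metric_name}"
    if key not in metrics:
        raise KeyError(f"metric {key!r} not found in {sorted(metrics)}")
    value = metrics[key]
    better = operator.gt if greater_is_better else operator.lt
    # The first evaluation always counts as the best seen so far.
    if best_so_far is None or better(value, best_so_far):
        return True, value
    return False, best_so_far


def epochs_to_save(
    evals: List[Dict[str, float]], metric_name: str, greater_is_better: bool
) -> List[float]:
    """List the epochs at which a "best" save strategy would request a save."""
    best: Optional[float] = None
    saved: List[float] = []
    for metrics in evals:
        new_best, best = is_new_best(metrics, metric_name, best, greater_is_better)
        if new_best:
            saved.append(metrics["epoch"])
    return saved
```

Fed the three mocked evaluations from Case 1 of the test (accuracy 0.60, 0.65, 0.64), this sketch requests saves at epochs 1 and 2 only, which matches the two checkpoints the test expects.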

* Added SaveStrategy.

Same as IntervalStrategy, but with a new attribute called BEST.
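
As an illustration only, the relationship between the two enums can be sketched with the standard library. The real classes live in the library and may use a different base class; the member values shown here are the ones accepted by `eval_strategy` and `save_strategy` in the test below, plus the usual "no" and "steps".

```python
from enum import Enum


class IntervalStrategy(str, Enum):
    """Simplified stand-in for the existing interval enum."""

    NO = "no"
    STEPS = "steps"
    EPOCH = "epoch"


class SaveStrategy(str, Enum):
    """Same members as IntervalStrategy, plus BEST."""

    NO = "no"
    STEPS = "steps"
    EPOCH = "epoch"
    BEST = "best"
```

Because both classes mix in `str`, a plain string such as `"best"` can be converted with `SaveStrategy("best")` and compared against string values directly.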

* IntervalStrategy -> SaveStrategy

* IntervalStrategy -> SaveStrategy for save_strat.

* Interval -> Save in docstring.

* Updated docstring for save_strategy.

* Added SaveStrategy and made corresponding changes.

`save_strategy` previously followed `IntervalStrategy` but now follows
`SaveStrategy`.

Corresponding changes were made to the code and the docstring.

* Changes from `make fixup`.

* Removed redundant metrics argument.

* Added new test_save_best_checkpoint test.

1. Checks both the case where `metric_for_best_model` is explicitly
provided and the case where it is not.
2. The first case should have two checkpoints saved, whereas the second
should have three saved.

* Changed should_training_end saving logic.

The Trainer saves a checkpoint at the end of training by default as
long as `save_strategy != SaveStrategy.NO`. This condition was modified
to also check for `SaveStrategy.BEST`, so that no final checkpoint is
forced under that strategy: it would be counterintuitive to ask for only
the best checkpoint and have the last one saved as well.
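
The condition can be paraphrased as follows. This is a sketch of the behaviour described above, not the library's code; `should_force_final_save` is a hypothetical name and the strategies are written as plain strings.

```python
def should_force_final_save(save_strategy: str) -> bool:
    """Return True if a checkpoint should be written when training ends."""
    # Before this change only "no" suppressed the final save;
    # "best" is now treated the same way.
    return save_strategy not in ("no", "best")
```

With `save_strategy="best"`, checkpoints are then written only when an evaluation produces a new best metric.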

* `args.metric_for_best_model` defaults to loss.

* Undo metric_for_best_model update.

* Remove checking metric_for_best_model.

* Added test cases for loss and no metric.

* Added error for metric and changed default best_metric.
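
A hedged sketch of the kind of check this refers to. The function name `validate_save_strategy` is hypothetical, and only the message fragment asserted in Case 3 of the test is taken from the source; the rest of the wording is an assumption.

```python
from typing import Optional


def validate_save_strategy(save_strategy: str, metric_for_best_model: Optional[str]) -> None:
    """Reject a "best" save strategy that has no metric to compare on."""
    if save_strategy == "best" and metric_for_best_model is None:
        # Only the "`args.metric_for_best_model` must be provided" part
        # appears in the test; the remainder is illustrative.
        raise ValueError(
            "`args.metric_for_best_model` must be provided "
            "when the save strategy is 'best'."
        )
```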

* Removed unused import.

* `new_best_metric` -> `is_new_best_metric`

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Applied `is_new_best_metric` to all.

Changes were made for consistency and also to fix a potential bug.

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Sean (Seok-Won) Yi
2024-10-29 00:02:22 +09:00
committed by GitHub
parent 8b3b9b48fc
commit c1753436db
6 changed files with 158 additions and 40 deletions

@@ -4041,6 +4041,89 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
                reloaded_tokenizer(test_sentence, padding="max_length").input_ids,
            )

    def test_save_best_checkpoint(self):
        freq = int(64 / self.batch_size)
        total = int(self.n_epochs * 64 / self.batch_size)

        # Case 1: args.metric_for_best_model == "accuracy".
        with tempfile.TemporaryDirectory() as tmpdir:
            trainer = get_regression_trainer(
                a=1.5,
                b=2.5,
                output_dir=tmpdir,
                learning_rate=0.1,
                eval_strategy="epoch",
                save_strategy="best",
                metric_for_best_model="accuracy",
                compute_metrics=AlmostAccuracy(),
            )
            self.assertTrue(trainer.args.metric_for_best_model == "accuracy")

            with patch.object(
                trainer,
                "_evaluate",
                side_effect=[
                    {"eval_loss": 0.03, "eval_accuracy": 0.60, "epoch": 1.0},
                    {"eval_loss": 0.02, "eval_accuracy": 0.65, "epoch": 2.0},
                    {"eval_loss": 0.01, "eval_accuracy": 0.64, "epoch": 3.0},
                ],
            ):
                trainer.train()

                self.assertEqual(len(os.listdir(tmpdir)), 2)
                self.check_saved_checkpoints(
                    output_dir=tmpdir,
                    freq=freq,
                    total=total,
                )

        # Case 2: args.metric_for_best_model == "loss".
        with tempfile.TemporaryDirectory() as tmpdir:
            trainer = get_regression_trainer(
                a=1.5,
                b=2.5,
                output_dir=tmpdir,
                learning_rate=0.1,
                eval_strategy="epoch",
                save_strategy="best",
                metric_for_best_model="loss",
                compute_metrics=AlmostAccuracy(),
            )
            self.assertTrue(trainer.args.metric_for_best_model == "loss")

            with patch.object(
                trainer,
                "_evaluate",
                side_effect=[
                    {"eval_loss": 0.03, "eval_accuracy": 0.60, "epoch": 1.0},
                    {"eval_loss": 0.02, "eval_accuracy": 0.65, "epoch": 2.0},
                    {"eval_loss": 0.03, "eval_accuracy": 0.66, "epoch": 3.0},
                ],
            ):
                trainer.train()

                self.assertEqual(len(os.listdir(tmpdir)), 2)
                self.check_saved_checkpoints(
                    output_dir=tmpdir,
                    freq=freq,
                    total=total,
                )

        # Case 3: Metric name not provided; throw error.
        with tempfile.TemporaryDirectory() as tmpdir:
            with self.assertRaises(ValueError) as context:
                trainer = get_regression_trainer(
                    a=1.5,
                    b=2.5,
                    output_dir=tmpdir,
                    learning_rate=0.1,
                    eval_strategy="epoch",
                    save_strategy="best",
                    compute_metrics=AlmostAccuracy(),
                )

            self.assertIn("`args.metric_for_best_model` must be provided", str(context.exception))


@require_torch
@is_staging_test