From c1753436dbb8bcbcee183cdd6eba9f08a90d602a Mon Sep 17 00:00:00 2001 From: "Sean (Seok-Won) Yi" Date: Tue, 29 Oct 2024 00:02:22 +0900 Subject: [PATCH] New option called `"best"` for `args.save_strategy`. (#31817) * Add _determine_best_metric and new saving logic. 1. Logic to determine the best logic was separated out from `_save_checkpoint`. 2. In `_maybe_log_save_evaluate`, whether or not a new best metric was achieved is determined after each evaluation, and if the save strategy is "best' then the TrainerControl is updated accordingly. * Added SaveStrategy. Same as IntervalStrategy, but with a new attribute called BEST. * IntervalStrategy -> SaveStrategy * IntervalStratgy -> SaveStrategy for save_strat. * Interval -> Save in docstring. * Updated docstring for save_strategy. * Added SaveStrategy and made according changes. `save_strategy` previously followed `IntervalStrategy` but now follows `SaveStrategy`. Changes were made accordingly to the code and the docstring. * Changes from `make fixup`. * Removed redundant metrics argument. * Added new test_save_best_checkpoint test. 1. Checks for both cases where `metric_for_best_model` is explicitly provided and when it's not provided. 2. The first case should have two checkpoints saved, whereas the second should have three saved. * Changed should_training_end saving logic. The Trainer saves a checkpoints at the end of training by default as long as `save_strategy != SaveStrategy.NO`. This condition was modified to include `SaveStrategy.BEST` because it would be counterintuitive that we'd only want the best checkpoint to be saved but the last one is as well. * `args.metric_for_best_model` default to loss. * Undo metric_for_best_model update. * Remove checking metric_for_best_model. * Added test cases for loss and no metric. * Added error for metric and changed default best_metric. * Removed unused import. * `new_best_metric` -> `is_new_best_metric` Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Applied `is_new_best_metric` to all. Changes were made for consistency and also to fix a potential bug. --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Zach Mueller --- src/transformers/trainer.py | 84 ++++++++++++++++++---------- src/transformers/trainer_callback.py | 8 +-- src/transformers/trainer_utils.py | 7 +++ src/transformers/training_args.py | 14 +++-- src/transformers/training_args_tf.py | 2 +- tests/trainer/test_trainer.py | 83 +++++++++++++++++++++++++++ 6 files changed, 158 insertions(+), 40 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 64cb5c6bd4..4315e54a42 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -117,9 +117,9 @@ from .trainer_utils import ( EvalPrediction, HPSearchBackend, HubStrategy, - IntervalStrategy, PredictionOutput, RemoveColumnsCollator, + SaveStrategy, TrainerMemoryTracker, TrainOutput, check_target_module_exists, @@ -419,6 +419,12 @@ class Trainer: raise ValueError( f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. " ) + if args.save_strategy == SaveStrategy.BEST or args.load_best_model_at_end: + if args.metric_for_best_model is None: + raise ValueError( + "`args.metric_for_best_model` must be provided when using 'best' save_strategy or if `args.load_best_model_at_end` is set to `True`." + ) + self.args = args self.compute_loss_func = compute_loss_func # Seed must be set before instantiating the model when using model @@ -2998,9 +3004,13 @@ class Trainer: metrics = None if self.control.should_evaluate: metrics = self._evaluate(trial, ignore_keys_for_eval) + is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial) + + if self.args.save_strategy == SaveStrategy.BEST: + self.control.should_save = is_new_best_metric if self.control.should_save: - self._save_checkpoint(model, trial, metrics=metrics) + self._save_checkpoint(model, trial) self.control = self.callback_handler.on_save(self.args, self.state, self.control) def _load_rng_state(self, checkpoint): @@ -3077,7 +3087,48 @@ class Trainer: "\nThis won't yield the same results as if the training had not been interrupted." ) - def _save_checkpoint(self, model, trial, metrics=None): + def _determine_best_metric(self, metrics, trial): + """ + Determine if the model should be saved based on the evaluation metrics. + If args.metric_for_best_model is not set, the loss is used. + + Returns: + bool: True if a new best metric was found, else False + """ + is_new_best_metric = False + + if self.args.metric_for_best_model is not None: + metric_to_check = self.args.metric_for_best_model + + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + + try: + metric_value = metrics[metric_to_check] + except KeyError as exc: + raise KeyError( + f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. " + f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments." + ) from exc + + operator = np.greater if self.args.greater_is_better else np.less + + if self.state.best_metric is None: + self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf") + + if operator(metric_value, self.state.best_metric): + run_dir = self._get_output_dir(trial=trial) + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + output_dir = os.path.join(run_dir, checkpoint_folder) + + self.state.best_metric = metric_value + self.state.best_model_checkpoint = output_dir + + is_new_best_metric = True + + return is_new_best_metric + + def _save_checkpoint(self, model, trial): # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we # want to save except FullyShardedDDP. # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" @@ -3098,31 +3149,6 @@ class Trainer: # Save RNG state self._save_rng_state(output_dir) - # Determine the new best metric / best model checkpoint - if metrics is not None and self.args.metric_for_best_model is not None: - metric_to_check = self.args.metric_for_best_model - if not metric_to_check.startswith("eval_"): - metric_to_check = f"eval_{metric_to_check}" - try: - metric_value = metrics[metric_to_check] - except KeyError as exc: - raise KeyError( - f"The `metric_for_best_model` training argument is set to '{metric_to_check}', " - f"which is not found in the evaluation metrics. " - f"The available evaluation metrics are: {list(metrics.keys())}. " - f"Please ensure that the `compute_metrics` function returns a dictionary that includes '{metric_to_check}' or " - f"consider changing the `metric_for_best_model` via the TrainingArguments." - ) from exc - - operator = np.greater if self.args.greater_is_better else np.less - if ( - self.state.best_metric is None - or self.state.best_model_checkpoint is None - or operator(metric_value, self.state.best_metric) - ): - self.state.best_metric = metric_value - self.state.best_model_checkpoint = output_dir - # Save the Trainer state if self.args.should_save: # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently @@ -4543,7 +4569,7 @@ class Trainer: # Same for the training arguments torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) - if self.args.save_strategy == IntervalStrategy.STEPS: + if self.args.save_strategy == SaveStrategy.STEPS: commit_message = f"Training in progress, step {self.state.global_step}" else: commit_message = f"Training in progress, epoch {int(self.state.epoch)}" diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 405874acf8..ce9f2a2673 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -24,7 +24,7 @@ from typing import Dict, List, Optional, Union import numpy as np from tqdm.auto import tqdm -from .trainer_utils import IntervalStrategy, has_length +from .trainer_utils import IntervalStrategy, SaveStrategy, has_length from .training_args import TrainingArguments from .utils import logging @@ -555,7 +555,7 @@ class DefaultFlowCallback(TrainerCallback): # Save if ( - args.save_strategy == IntervalStrategy.STEPS + args.save_strategy == SaveStrategy.STEPS and state.save_steps > 0 and state.global_step % state.save_steps == 0 ): @@ -565,7 +565,7 @@ class DefaultFlowCallback(TrainerCallback): if state.global_step >= state.max_steps: control.should_training_stop = True # Save the model at the end if we have a save strategy - if args.save_strategy != IntervalStrategy.NO: + if args.save_strategy not in [SaveStrategy.NO, SaveStrategy.BEST]: control.should_save = True return control @@ -580,7 +580,7 @@ class DefaultFlowCallback(TrainerCallback): control.should_evaluate = True # Save - if args.save_strategy == IntervalStrategy.EPOCH: + if args.save_strategy == SaveStrategy.EPOCH: control.should_save = True return control diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 02c298cf7d..42088cd730 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -227,6 +227,13 @@ class IntervalStrategy(ExplicitEnum): EPOCH = "epoch" +class SaveStrategy(ExplicitEnum): + NO = "no" + STEPS = "steps" + EPOCH = "epoch" + BEST = "best" + + class EvaluationStrategy(ExplicitEnum): NO = "no" STEPS = "steps" diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 485610dd9b..c98e8bc41b 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -33,6 +33,7 @@ from .trainer_utils import ( FSDPOption, HubStrategy, IntervalStrategy, + SaveStrategy, SchedulerType, ) from .utils import ( @@ -349,12 +350,13 @@ class TrainingArguments: - save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`): + save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`): The checkpoint save strategy to adopt during training. Possible values are: - `"no"`: No save is done during training. - `"epoch"`: Save is done at the end of each epoch. - `"steps"`: Save is done every `save_steps`. + - `"best"`: Save is done whenever a new `best_metric` is achieved. If `"epoch"` or `"steps"` is chosen, saving will also be performed at the very end of training, always. @@ -962,7 +964,7 @@ class TrainingArguments: }, ) logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."}) - save_strategy: Union[IntervalStrategy, str] = field( + save_strategy: Union[SaveStrategy, str] = field( default="steps", metadata={"help": "The checkpoint save strategy to use."}, ) @@ -1580,7 +1582,7 @@ class TrainingArguments: self.eval_strategy = IntervalStrategy(self.eval_strategy) self.logging_strategy = IntervalStrategy(self.logging_strategy) - self.save_strategy = IntervalStrategy(self.save_strategy) + self.save_strategy = SaveStrategy(self.save_strategy) self.hub_strategy = HubStrategy(self.hub_strategy) self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type) @@ -1616,7 +1618,7 @@ class TrainingArguments: if self.eval_steps != int(self.eval_steps): raise ValueError(f"--eval_steps must be an integer if bigger than 1: {self.eval_steps}") self.eval_steps = int(self.eval_steps) - if self.save_strategy == IntervalStrategy.STEPS and self.save_steps > 1: + if self.save_strategy == SaveStrategy.STEPS and self.save_steps > 1: if self.save_steps != int(self.save_steps): raise ValueError(f"--save_steps must be an integer if bigger than 1: {self.save_steps}") self.save_steps = int(self.save_steps) @@ -2750,8 +2752,8 @@ class TrainingArguments: 100 ``` """ - self.save_strategy = IntervalStrategy(strategy) - if self.save_strategy == IntervalStrategy.STEPS and steps == 0: + self.save_strategy = SaveStrategy(strategy) + if self.save_strategy == SaveStrategy.STEPS and steps == 0: raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.") self.save_steps = steps self.save_total_limit = total_limit diff --git a/src/transformers/training_args_tf.py b/src/transformers/training_args_tf.py index 9df53c3f1d..3716a78879 100644 --- a/src/transformers/training_args_tf.py +++ b/src/transformers/training_args_tf.py @@ -114,7 +114,7 @@ class TFTrainingArguments(TrainingArguments): Whether to log and evaluate the first `global_step` or not. logging_steps (`int`, *optional*, defaults to 500): Number of update steps between two logs if `logging_strategy="steps"`. - save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`): + save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`): The checkpoint save strategy to adopt during training. Possible values are: - `"no"`: No save is done during training. diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 5c03355785..b6fe807fa4 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -4041,6 +4041,89 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): reloaded_tokenizer(test_sentence, padding="max_length").input_ids, ) + def test_save_best_checkpoint(self): + freq = int(64 / self.batch_size) + total = int(self.n_epochs * 64 / self.batch_size) + + # Case 1: args.metric_for_best_model == "accuracy". + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + a=1.5, + b=2.5, + output_dir=tmpdir, + learning_rate=0.1, + eval_strategy="epoch", + save_strategy="best", + metric_for_best_model="accuracy", + compute_metrics=AlmostAccuracy(), + ) + self.assertTrue(trainer.args.metric_for_best_model == "accuracy") + + with patch.object( + trainer, + "_evaluate", + side_effect=[ + {"eval_loss": 0.03, "eval_accuracy": 0.60, "epoch": 1.0}, + {"eval_loss": 0.02, "eval_accuracy": 0.65, "epoch": 2.0}, + {"eval_loss": 0.01, "eval_accuracy": 0.64, "epoch": 3.0}, + ], + ): + trainer.train() + + self.assertEqual(len(os.listdir(tmpdir)), 2) + self.check_saved_checkpoints( + output_dir=tmpdir, + freq=freq, + total=total, + ) + + # Case 2: args.metric_for_best_model == "loss". + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + a=1.5, + b=2.5, + output_dir=tmpdir, + learning_rate=0.1, + eval_strategy="epoch", + save_strategy="best", + metric_for_best_model="loss", + compute_metrics=AlmostAccuracy(), + ) + self.assertTrue(trainer.args.metric_for_best_model == "loss") + + with patch.object( + trainer, + "_evaluate", + side_effect=[ + {"eval_loss": 0.03, "eval_accuracy": 0.60, "epoch": 1.0}, + {"eval_loss": 0.02, "eval_accuracy": 0.65, "epoch": 2.0}, + {"eval_loss": 0.03, "eval_accuracy": 0.66, "epoch": 3.0}, + ], + ): + trainer.train() + + self.assertEqual(len(os.listdir(tmpdir)), 2) + self.check_saved_checkpoints( + output_dir=tmpdir, + freq=freq, + total=total, + ) + + # Case 3: Metric name not provided; throw error. + with tempfile.TemporaryDirectory() as tmpdir: + with self.assertRaises(ValueError) as context: + trainer = get_regression_trainer( + a=1.5, + b=2.5, + output_dir=tmpdir, + learning_rate=0.1, + eval_strategy="epoch", + save_strategy="best", + compute_metrics=AlmostAccuracy(), + ) + + self.assertIn("`args.metric_for_best_model` must be provided", str(context.exception)) + @require_torch @is_staging_test