New option called "best" for args.save_strategy. (#31817)
* Add _determine_best_metric and new saving logic. 1. Logic to determine the best metric was separated out from `_save_checkpoint`. 2. In `_maybe_log_save_evaluate`, whether or not a new best metric was achieved is determined after each evaluation, and if the save strategy is "best" then the TrainerControl is updated accordingly. * Added SaveStrategy. Same as IntervalStrategy, but with a new attribute called BEST. * IntervalStrategy -> SaveStrategy * IntervalStrategy -> SaveStrategy for save_strat. * Interval -> Save in docstring. * Updated docstring for save_strategy. * Added SaveStrategy and made corresponding changes. `save_strategy` previously followed `IntervalStrategy` but now follows `SaveStrategy`. Changes were made accordingly to the code and the docstring. * Changes from `make fixup`. * Removed redundant metrics argument. * Added new test_save_best_checkpoint test. 1. Checks for both cases where `metric_for_best_model` is explicitly provided and when it's not provided. 2. The first case should have two checkpoints saved, whereas the second should have three saved. * Changed should_training_end saving logic. The Trainer saves a checkpoint at the end of training by default as long as `save_strategy != SaveStrategy.NO`. This condition was modified to include `SaveStrategy.BEST` because it would be counterintuitive to ask for only the best checkpoint to be saved and still have the last one saved as well. * `args.metric_for_best_model` defaults to loss. * Undo metric_for_best_model update. * Remove checking metric_for_best_model. * Added test cases for loss and no metric. * Added error for metric and changed default best_metric. * Removed unused import. * `new_best_metric` -> `is_new_best_metric` Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Applied `is_new_best_metric` to all. Changes were made for consistency and also to fix a potential bug. 
--------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Zach Mueller <muellerzr@gmail.com>
This commit is contained in:
committed by
GitHub
parent
8b3b9b48fc
commit
c1753436db
@@ -117,9 +117,9 @@ from .trainer_utils import (
|
|||||||
EvalPrediction,
|
EvalPrediction,
|
||||||
HPSearchBackend,
|
HPSearchBackend,
|
||||||
HubStrategy,
|
HubStrategy,
|
||||||
IntervalStrategy,
|
|
||||||
PredictionOutput,
|
PredictionOutput,
|
||||||
RemoveColumnsCollator,
|
RemoveColumnsCollator,
|
||||||
|
SaveStrategy,
|
||||||
TrainerMemoryTracker,
|
TrainerMemoryTracker,
|
||||||
TrainOutput,
|
TrainOutput,
|
||||||
check_target_module_exists,
|
check_target_module_exists,
|
||||||
@@ -419,6 +419,12 @@ class Trainer:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. "
|
f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. "
|
||||||
)
|
)
|
||||||
|
if args.save_strategy == SaveStrategy.BEST or args.load_best_model_at_end:
|
||||||
|
if args.metric_for_best_model is None:
|
||||||
|
raise ValueError(
|
||||||
|
"`args.metric_for_best_model` must be provided when using 'best' save_strategy or if `args.load_best_model_at_end` is set to `True`."
|
||||||
|
)
|
||||||
|
|
||||||
self.args = args
|
self.args = args
|
||||||
self.compute_loss_func = compute_loss_func
|
self.compute_loss_func = compute_loss_func
|
||||||
# Seed must be set before instantiating the model when using model
|
# Seed must be set before instantiating the model when using model
|
||||||
@@ -2998,9 +3004,13 @@ class Trainer:
|
|||||||
metrics = None
|
metrics = None
|
||||||
if self.control.should_evaluate:
|
if self.control.should_evaluate:
|
||||||
metrics = self._evaluate(trial, ignore_keys_for_eval)
|
metrics = self._evaluate(trial, ignore_keys_for_eval)
|
||||||
|
is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
|
||||||
|
|
||||||
|
if self.args.save_strategy == SaveStrategy.BEST:
|
||||||
|
self.control.should_save = is_new_best_metric
|
||||||
|
|
||||||
if self.control.should_save:
|
if self.control.should_save:
|
||||||
self._save_checkpoint(model, trial, metrics=metrics)
|
self._save_checkpoint(model, trial)
|
||||||
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
||||||
|
|
||||||
def _load_rng_state(self, checkpoint):
|
def _load_rng_state(self, checkpoint):
|
||||||
@@ -3077,7 +3087,48 @@ class Trainer:
|
|||||||
"\nThis won't yield the same results as if the training had not been interrupted."
|
"\nThis won't yield the same results as if the training had not been interrupted."
|
||||||
)
|
)
|
||||||
|
|
||||||
def _save_checkpoint(self, model, trial, metrics=None):
|
def _determine_best_metric(self, metrics, trial):
|
||||||
|
"""
|
||||||
|
Determine if the model should be saved based on the evaluation metrics.
|
||||||
|
If args.metric_for_best_model is not set, no comparison is made and False is returned.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if a new best metric was found, else False
|
||||||
|
"""
|
||||||
|
is_new_best_metric = False
|
||||||
|
|
||||||
|
if self.args.metric_for_best_model is not None:
|
||||||
|
metric_to_check = self.args.metric_for_best_model
|
||||||
|
|
||||||
|
if not metric_to_check.startswith("eval_"):
|
||||||
|
metric_to_check = f"eval_{metric_to_check}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
metric_value = metrics[metric_to_check]
|
||||||
|
except KeyError as exc:
|
||||||
|
raise KeyError(
|
||||||
|
f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
|
||||||
|
f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
operator = np.greater if self.args.greater_is_better else np.less
|
||||||
|
|
||||||
|
if self.state.best_metric is None:
|
||||||
|
self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
|
||||||
|
|
||||||
|
if operator(metric_value, self.state.best_metric):
|
||||||
|
run_dir = self._get_output_dir(trial=trial)
|
||||||
|
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||||
|
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||||
|
|
||||||
|
self.state.best_metric = metric_value
|
||||||
|
self.state.best_model_checkpoint = output_dir
|
||||||
|
|
||||||
|
is_new_best_metric = True
|
||||||
|
|
||||||
|
return is_new_best_metric
|
||||||
|
|
||||||
|
def _save_checkpoint(self, model, trial):
|
||||||
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
|
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
|
||||||
# want to save except FullyShardedDDP.
|
# want to save except FullyShardedDDP.
|
||||||
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
|
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
|
||||||
@@ -3098,31 +3149,6 @@ class Trainer:
|
|||||||
# Save RNG state
|
# Save RNG state
|
||||||
self._save_rng_state(output_dir)
|
self._save_rng_state(output_dir)
|
||||||
|
|
||||||
# Determine the new best metric / best model checkpoint
|
|
||||||
if metrics is not None and self.args.metric_for_best_model is not None:
|
|
||||||
metric_to_check = self.args.metric_for_best_model
|
|
||||||
if not metric_to_check.startswith("eval_"):
|
|
||||||
metric_to_check = f"eval_{metric_to_check}"
|
|
||||||
try:
|
|
||||||
metric_value = metrics[metric_to_check]
|
|
||||||
except KeyError as exc:
|
|
||||||
raise KeyError(
|
|
||||||
f"The `metric_for_best_model` training argument is set to '{metric_to_check}', "
|
|
||||||
f"which is not found in the evaluation metrics. "
|
|
||||||
f"The available evaluation metrics are: {list(metrics.keys())}. "
|
|
||||||
f"Please ensure that the `compute_metrics` function returns a dictionary that includes '{metric_to_check}' or "
|
|
||||||
f"consider changing the `metric_for_best_model` via the TrainingArguments."
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
operator = np.greater if self.args.greater_is_better else np.less
|
|
||||||
if (
|
|
||||||
self.state.best_metric is None
|
|
||||||
or self.state.best_model_checkpoint is None
|
|
||||||
or operator(metric_value, self.state.best_metric)
|
|
||||||
):
|
|
||||||
self.state.best_metric = metric_value
|
|
||||||
self.state.best_model_checkpoint = output_dir
|
|
||||||
|
|
||||||
# Save the Trainer state
|
# Save the Trainer state
|
||||||
if self.args.should_save:
|
if self.args.should_save:
|
||||||
# Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
|
# Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
|
||||||
@@ -4543,7 +4569,7 @@ class Trainer:
|
|||||||
# Same for the training arguments
|
# Same for the training arguments
|
||||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||||
|
|
||||||
if self.args.save_strategy == IntervalStrategy.STEPS:
|
if self.args.save_strategy == SaveStrategy.STEPS:
|
||||||
commit_message = f"Training in progress, step {self.state.global_step}"
|
commit_message = f"Training in progress, step {self.state.global_step}"
|
||||||
else:
|
else:
|
||||||
commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
|
commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from typing import Dict, List, Optional, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
from .trainer_utils import IntervalStrategy, has_length
|
from .trainer_utils import IntervalStrategy, SaveStrategy, has_length
|
||||||
from .training_args import TrainingArguments
|
from .training_args import TrainingArguments
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
@@ -555,7 +555,7 @@ class DefaultFlowCallback(TrainerCallback):
|
|||||||
|
|
||||||
# Save
|
# Save
|
||||||
if (
|
if (
|
||||||
args.save_strategy == IntervalStrategy.STEPS
|
args.save_strategy == SaveStrategy.STEPS
|
||||||
and state.save_steps > 0
|
and state.save_steps > 0
|
||||||
and state.global_step % state.save_steps == 0
|
and state.global_step % state.save_steps == 0
|
||||||
):
|
):
|
||||||
@@ -565,7 +565,7 @@ class DefaultFlowCallback(TrainerCallback):
|
|||||||
if state.global_step >= state.max_steps:
|
if state.global_step >= state.max_steps:
|
||||||
control.should_training_stop = True
|
control.should_training_stop = True
|
||||||
# Save the model at the end if we have a save strategy
|
# Save the model at the end if we have a save strategy
|
||||||
if args.save_strategy != IntervalStrategy.NO:
|
if args.save_strategy not in [SaveStrategy.NO, SaveStrategy.BEST]:
|
||||||
control.should_save = True
|
control.should_save = True
|
||||||
|
|
||||||
return control
|
return control
|
||||||
@@ -580,7 +580,7 @@ class DefaultFlowCallback(TrainerCallback):
|
|||||||
control.should_evaluate = True
|
control.should_evaluate = True
|
||||||
|
|
||||||
# Save
|
# Save
|
||||||
if args.save_strategy == IntervalStrategy.EPOCH:
|
if args.save_strategy == SaveStrategy.EPOCH:
|
||||||
control.should_save = True
|
control.should_save = True
|
||||||
|
|
||||||
return control
|
return control
|
||||||
|
|||||||
@@ -227,6 +227,13 @@ class IntervalStrategy(ExplicitEnum):
|
|||||||
EPOCH = "epoch"
|
EPOCH = "epoch"
|
||||||
|
|
||||||
|
|
||||||
|
class SaveStrategy(ExplicitEnum):
|
||||||
|
NO = "no"
|
||||||
|
STEPS = "steps"
|
||||||
|
EPOCH = "epoch"
|
||||||
|
BEST = "best"
|
||||||
|
|
||||||
|
|
||||||
class EvaluationStrategy(ExplicitEnum):
|
class EvaluationStrategy(ExplicitEnum):
|
||||||
NO = "no"
|
NO = "no"
|
||||||
STEPS = "steps"
|
STEPS = "steps"
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from .trainer_utils import (
|
|||||||
FSDPOption,
|
FSDPOption,
|
||||||
HubStrategy,
|
HubStrategy,
|
||||||
IntervalStrategy,
|
IntervalStrategy,
|
||||||
|
SaveStrategy,
|
||||||
SchedulerType,
|
SchedulerType,
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
@@ -349,12 +350,13 @@ class TrainingArguments:
|
|||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
|
save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`):
|
||||||
The checkpoint save strategy to adopt during training. Possible values are:
|
The checkpoint save strategy to adopt during training. Possible values are:
|
||||||
|
|
||||||
- `"no"`: No save is done during training.
|
- `"no"`: No save is done during training.
|
||||||
- `"epoch"`: Save is done at the end of each epoch.
|
- `"epoch"`: Save is done at the end of each epoch.
|
||||||
- `"steps"`: Save is done every `save_steps`.
|
- `"steps"`: Save is done every `save_steps`.
|
||||||
|
- `"best"`: Save is done whenever a new `best_metric` is achieved.
|
||||||
|
|
||||||
If `"epoch"` or `"steps"` is chosen, saving will also be performed at the
|
If `"epoch"` or `"steps"` is chosen, saving will also be performed at the
|
||||||
very end of training, always.
|
very end of training, always.
|
||||||
@@ -962,7 +964,7 @@ class TrainingArguments:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."})
|
logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."})
|
||||||
save_strategy: Union[IntervalStrategy, str] = field(
|
save_strategy: Union[SaveStrategy, str] = field(
|
||||||
default="steps",
|
default="steps",
|
||||||
metadata={"help": "The checkpoint save strategy to use."},
|
metadata={"help": "The checkpoint save strategy to use."},
|
||||||
)
|
)
|
||||||
@@ -1580,7 +1582,7 @@ class TrainingArguments:
|
|||||||
|
|
||||||
self.eval_strategy = IntervalStrategy(self.eval_strategy)
|
self.eval_strategy = IntervalStrategy(self.eval_strategy)
|
||||||
self.logging_strategy = IntervalStrategy(self.logging_strategy)
|
self.logging_strategy = IntervalStrategy(self.logging_strategy)
|
||||||
self.save_strategy = IntervalStrategy(self.save_strategy)
|
self.save_strategy = SaveStrategy(self.save_strategy)
|
||||||
self.hub_strategy = HubStrategy(self.hub_strategy)
|
self.hub_strategy = HubStrategy(self.hub_strategy)
|
||||||
|
|
||||||
self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
|
self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
|
||||||
@@ -1616,7 +1618,7 @@ class TrainingArguments:
|
|||||||
if self.eval_steps != int(self.eval_steps):
|
if self.eval_steps != int(self.eval_steps):
|
||||||
raise ValueError(f"--eval_steps must be an integer if bigger than 1: {self.eval_steps}")
|
raise ValueError(f"--eval_steps must be an integer if bigger than 1: {self.eval_steps}")
|
||||||
self.eval_steps = int(self.eval_steps)
|
self.eval_steps = int(self.eval_steps)
|
||||||
if self.save_strategy == IntervalStrategy.STEPS and self.save_steps > 1:
|
if self.save_strategy == SaveStrategy.STEPS and self.save_steps > 1:
|
||||||
if self.save_steps != int(self.save_steps):
|
if self.save_steps != int(self.save_steps):
|
||||||
raise ValueError(f"--save_steps must be an integer if bigger than 1: {self.save_steps}")
|
raise ValueError(f"--save_steps must be an integer if bigger than 1: {self.save_steps}")
|
||||||
self.save_steps = int(self.save_steps)
|
self.save_steps = int(self.save_steps)
|
||||||
@@ -2750,8 +2752,8 @@ class TrainingArguments:
|
|||||||
100
|
100
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
self.save_strategy = IntervalStrategy(strategy)
|
self.save_strategy = SaveStrategy(strategy)
|
||||||
if self.save_strategy == IntervalStrategy.STEPS and steps == 0:
|
if self.save_strategy == SaveStrategy.STEPS and steps == 0:
|
||||||
raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.")
|
raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.")
|
||||||
self.save_steps = steps
|
self.save_steps = steps
|
||||||
self.save_total_limit = total_limit
|
self.save_total_limit = total_limit
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ class TFTrainingArguments(TrainingArguments):
|
|||||||
Whether to log and evaluate the first `global_step` or not.
|
Whether to log and evaluate the first `global_step` or not.
|
||||||
logging_steps (`int`, *optional*, defaults to 500):
|
logging_steps (`int`, *optional*, defaults to 500):
|
||||||
Number of update steps between two logs if `logging_strategy="steps"`.
|
Number of update steps between two logs if `logging_strategy="steps"`.
|
||||||
save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
|
save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`):
|
||||||
The checkpoint save strategy to adopt during training. Possible values are:
|
The checkpoint save strategy to adopt during training. Possible values are:
|
||||||
|
|
||||||
- `"no"`: No save is done during training.
|
- `"no"`: No save is done during training.
|
||||||
|
|||||||
@@ -4041,6 +4041,89 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
reloaded_tokenizer(test_sentence, padding="max_length").input_ids,
|
reloaded_tokenizer(test_sentence, padding="max_length").input_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_save_best_checkpoint(self):
|
||||||
|
freq = int(64 / self.batch_size)
|
||||||
|
total = int(self.n_epochs * 64 / self.batch_size)
|
||||||
|
|
||||||
|
# Case 1: args.metric_for_best_model == "accuracy".
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
a=1.5,
|
||||||
|
b=2.5,
|
||||||
|
output_dir=tmpdir,
|
||||||
|
learning_rate=0.1,
|
||||||
|
eval_strategy="epoch",
|
||||||
|
save_strategy="best",
|
||||||
|
metric_for_best_model="accuracy",
|
||||||
|
compute_metrics=AlmostAccuracy(),
|
||||||
|
)
|
||||||
|
self.assertTrue(trainer.args.metric_for_best_model == "accuracy")
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
trainer,
|
||||||
|
"_evaluate",
|
||||||
|
side_effect=[
|
||||||
|
{"eval_loss": 0.03, "eval_accuracy": 0.60, "epoch": 1.0},
|
||||||
|
{"eval_loss": 0.02, "eval_accuracy": 0.65, "epoch": 2.0},
|
||||||
|
{"eval_loss": 0.01, "eval_accuracy": 0.64, "epoch": 3.0},
|
||||||
|
],
|
||||||
|
):
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
self.assertEqual(len(os.listdir(tmpdir)), 2)
|
||||||
|
self.check_saved_checkpoints(
|
||||||
|
output_dir=tmpdir,
|
||||||
|
freq=freq,
|
||||||
|
total=total,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Case 2: args.metric_for_best_model == "loss".
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
a=1.5,
|
||||||
|
b=2.5,
|
||||||
|
output_dir=tmpdir,
|
||||||
|
learning_rate=0.1,
|
||||||
|
eval_strategy="epoch",
|
||||||
|
save_strategy="best",
|
||||||
|
metric_for_best_model="loss",
|
||||||
|
compute_metrics=AlmostAccuracy(),
|
||||||
|
)
|
||||||
|
self.assertTrue(trainer.args.metric_for_best_model == "loss")
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
trainer,
|
||||||
|
"_evaluate",
|
||||||
|
side_effect=[
|
||||||
|
{"eval_loss": 0.03, "eval_accuracy": 0.60, "epoch": 1.0},
|
||||||
|
{"eval_loss": 0.02, "eval_accuracy": 0.65, "epoch": 2.0},
|
||||||
|
{"eval_loss": 0.03, "eval_accuracy": 0.66, "epoch": 3.0},
|
||||||
|
],
|
||||||
|
):
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
self.assertEqual(len(os.listdir(tmpdir)), 2)
|
||||||
|
self.check_saved_checkpoints(
|
||||||
|
output_dir=tmpdir,
|
||||||
|
freq=freq,
|
||||||
|
total=total,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Case 3: Metric name not provided; throw error.
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
with self.assertRaises(ValueError) as context:
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
a=1.5,
|
||||||
|
b=2.5,
|
||||||
|
output_dir=tmpdir,
|
||||||
|
learning_rate=0.1,
|
||||||
|
eval_strategy="epoch",
|
||||||
|
save_strategy="best",
|
||||||
|
compute_metrics=AlmostAccuracy(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIn("`args.metric_for_best_model` must be provided", str(context.exception))
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
Reference in New Issue
Block a user