Introduce save_strategy training argument (#10286)
* Introduce save_strategy training argument * deprecate EvaluationStrategy * collapse EvaluationStrategy and LoggingStrategy into a single IntervalStrategy enum * modify tests to use modified enum
This commit is contained in:
@@ -255,7 +255,7 @@ _import_structure = {
|
||||
"TrainerControl",
|
||||
"TrainerState",
|
||||
],
|
||||
"trainer_utils": ["EvalPrediction", "EvaluationStrategy", "SchedulerType", "set_seed"],
|
||||
"trainer_utils": ["EvalPrediction", "IntervalStrategy", "SchedulerType", "set_seed"],
|
||||
"training_args": ["TrainingArguments"],
|
||||
"training_args_seq2seq": ["Seq2SeqTrainingArguments"],
|
||||
"training_args_tf": ["TFTrainingArguments"],
|
||||
@@ -1429,7 +1429,7 @@ if TYPE_CHECKING:
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
)
|
||||
from .trainer_utils import EvalPrediction, EvaluationStrategy, SchedulerType, set_seed
|
||||
from .trainer_utils import EvalPrediction, IntervalStrategy, SchedulerType, set_seed
|
||||
from .training_args import TrainingArguments
|
||||
from .training_args_seq2seq import Seq2SeqTrainingArguments
|
||||
from .training_args_tf import TFTrainingArguments
|
||||
|
||||
@@ -48,7 +48,7 @@ if _has_comet:
|
||||
|
||||
from .file_utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402
|
||||
from .trainer_callback import TrainerCallback # noqa: E402
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, EvaluationStrategy # noqa: E402
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
|
||||
|
||||
|
||||
# Integration functions:
|
||||
@@ -219,7 +219,7 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
|
||||
# Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
|
||||
if isinstance(
|
||||
kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
|
||||
) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == EvaluationStrategy.NO):
|
||||
) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == IntervalStrategy.NO):
|
||||
raise RuntimeError(
|
||||
"You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
|
||||
"This means your trials will not report intermediate results to Ray Tune, and "
|
||||
|
||||
@@ -24,7 +24,7 @@ from typing import Dict, List, Optional, Union
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from .trainer_utils import EvaluationStrategy, LoggingStrategy
|
||||
from .trainer_utils import IntervalStrategy
|
||||
from .training_args import TrainingArguments
|
||||
from .utils import logging
|
||||
|
||||
@@ -404,20 +404,25 @@ class DefaultFlowCallback(TrainerCallback):
|
||||
if state.global_step == 1 and args.logging_first_step:
|
||||
control.should_log = True
|
||||
if (
|
||||
args.logging_strategy == LoggingStrategy.STEPS
|
||||
args.logging_strategy == IntervalStrategy.STEPS
|
||||
and args.logging_steps > 0
|
||||
and state.global_step % args.logging_steps == 0
|
||||
):
|
||||
control.should_log = True
|
||||
|
||||
# Evaluate
|
||||
if args.evaluation_strategy == EvaluationStrategy.STEPS and state.global_step % args.eval_steps == 0:
|
||||
if args.evaluation_strategy == IntervalStrategy.STEPS and state.global_step % args.eval_steps == 0:
|
||||
control.should_evaluate = True
|
||||
if args.load_best_model_at_end:
|
||||
control.should_save = True
|
||||
|
||||
# Save
|
||||
if not args.load_best_model_at_end and args.save_steps > 0 and state.global_step % args.save_steps == 0:
|
||||
if (
|
||||
not args.load_best_model_at_end
|
||||
and args.save_strategy == IntervalStrategy.STEPS
|
||||
and args.save_steps > 0
|
||||
and state.global_step % args.save_steps == 0
|
||||
):
|
||||
control.should_save = True
|
||||
|
||||
# End training
|
||||
@@ -428,14 +433,19 @@ class DefaultFlowCallback(TrainerCallback):
|
||||
|
||||
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
# Log
|
||||
if args.logging_strategy == LoggingStrategy.EPOCH:
|
||||
if args.logging_strategy == IntervalStrategy.EPOCH:
|
||||
control.should_log = True
|
||||
|
||||
# Evaluate
|
||||
if args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
||||
if args.evaluation_strategy == IntervalStrategy.EPOCH:
|
||||
control.should_evaluate = True
|
||||
if args.load_best_model_at_end:
|
||||
control.should_save = True
|
||||
|
||||
# Save
|
||||
if args.save_strategy == IntervalStrategy.EPOCH:
|
||||
control.should_save = True
|
||||
|
||||
return control
|
||||
|
||||
|
||||
@@ -531,8 +541,8 @@ class EarlyStoppingCallback(TrainerCallback):
|
||||
args.metric_for_best_model is not None
|
||||
), "EarlyStoppingCallback requires metric_for_best_model is defined"
|
||||
assert (
|
||||
args.evaluation_strategy != EvaluationStrategy.NO
|
||||
), "EarlyStoppingCallback requires EvaluationStrategy of steps or epoch"
|
||||
args.evaluation_strategy != IntervalStrategy.NO
|
||||
), "EarlyStoppingCallback requires IntervalStrategy of steps or epoch"
|
||||
|
||||
def on_evaluate(self, args, state, control, metrics, **kwargs):
|
||||
metric_to_check = args.metric_for_best_model
|
||||
|
||||
@@ -33,7 +33,7 @@ from tensorflow.python.distribute.values import PerReplica
|
||||
|
||||
from .modeling_tf_utils import TFPreTrainedModel
|
||||
from .optimization_tf import GradientAccumulator, create_optimizer
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, EvaluationStrategy, PredictionOutput, set_seed
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, IntervalStrategy, PredictionOutput, set_seed
|
||||
from .training_args_tf import TFTrainingArguments
|
||||
from .utils import logging
|
||||
|
||||
@@ -574,7 +574,7 @@ class TFTrainer:
|
||||
|
||||
if (
|
||||
self.args.eval_steps > 0
|
||||
and self.args.evaluation_strategy == EvaluationStrategy.STEPS
|
||||
and self.args.evaluation_strategy == IntervalStrategy.STEPS
|
||||
and self.global_step % self.args.eval_steps == 0
|
||||
):
|
||||
self.evaluate()
|
||||
|
||||
@@ -101,13 +101,13 @@ def get_last_checkpoint(folder):
|
||||
return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
|
||||
|
||||
|
||||
class EvaluationStrategy(ExplicitEnum):
|
||||
class IntervalStrategy(ExplicitEnum):
|
||||
NO = "no"
|
||||
STEPS = "steps"
|
||||
EPOCH = "epoch"
|
||||
|
||||
|
||||
class LoggingStrategy(ExplicitEnum):
|
||||
class EvaluationStrategy(ExplicitEnum):
|
||||
NO = "no"
|
||||
STEPS = "steps"
|
||||
EPOCH = "epoch"
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -25,7 +26,7 @@ from .file_utils import (
|
||||
is_torch_tpu_available,
|
||||
torch_required,
|
||||
)
|
||||
from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType, ShardedDDPOption
|
||||
from .trainer_utils import EvaluationStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
|
||||
from .utils import logging
|
||||
|
||||
|
||||
@@ -84,7 +85,7 @@ class TrainingArguments:
|
||||
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
|
||||
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
|
||||
details.
|
||||
evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
|
||||
evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`):
|
||||
The evaluation strategy to adopt during training. Possible values are:
|
||||
|
||||
* :obj:`"no"`: No evaluation is done during training.
|
||||
@@ -139,7 +140,7 @@ class TrainingArguments:
|
||||
logging_dir (:obj:`str`, `optional`):
|
||||
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
|
||||
`runs/**CURRENT_DATETIME_HOSTNAME**`.
|
||||
logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.LoggingStrategy`, `optional`, defaults to :obj:`"steps"`):
|
||||
logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
|
||||
The logging strategy to adopt during training. Possible values are:
|
||||
|
||||
* :obj:`"no"`: No logging is done during training.
|
||||
@@ -150,8 +151,15 @@ class TrainingArguments:
|
||||
Whether to log and evaluate the first :obj:`global_step` or not.
|
||||
logging_steps (:obj:`int`, `optional`, defaults to 500):
|
||||
Number of update steps between two logs if :obj:`logging_strategy="steps"`.
|
||||
save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
|
||||
The checkpoint save strategy to adopt during training. Possible values are:
|
||||
|
||||
* :obj:`"no"`: No save is done during training.
|
||||
* :obj:`"epoch"`: Save is done at the end of each epoch.
|
||||
* :obj:`"steps"`: Save is done every :obj:`save_steps`.
|
||||
|
||||
save_steps (:obj:`int`, `optional`, defaults to 500):
|
||||
Number of updates steps before two checkpoint saves.
|
||||
Number of updates steps before two checkpoint saves if :obj:`save_strategy="steps"`.
|
||||
save_total_limit (:obj:`int`, `optional`):
|
||||
If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
|
||||
:obj:`output_dir`.
|
||||
@@ -215,8 +223,8 @@ class TrainingArguments:
|
||||
|
||||
.. note::
|
||||
|
||||
When set to :obj:`True`, the parameters :obj:`save_steps` will be ignored and the model will be saved
|
||||
after each evaluation.
|
||||
When set to :obj:`True`, the parameters :obj:`save_strategy` and :obj:`save_steps` will be ignored and
|
||||
the model will be saved after each evaluation.
|
||||
metric_for_best_model (:obj:`str`, `optional`):
|
||||
Use in conjunction with :obj:`load_best_model_at_end` to specify the metric to use to compare two different
|
||||
models. Must be the name of a metric returned by the evaluation with or without the prefix :obj:`"eval_"`.
|
||||
@@ -297,7 +305,7 @@ class TrainingArguments:
|
||||
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
||||
do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."})
|
||||
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
|
||||
evaluation_strategy: EvaluationStrategy = field(
|
||||
evaluation_strategy: IntervalStrategy = field(
|
||||
default="no",
|
||||
metadata={"help": "The evaluation strategy to use."},
|
||||
)
|
||||
@@ -359,12 +367,16 @@ class TrainingArguments:
|
||||
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
||||
|
||||
logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
|
||||
logging_strategy: LoggingStrategy = field(
|
||||
logging_strategy: IntervalStrategy = field(
|
||||
default="steps",
|
||||
metadata={"help": "The logging strategy to use."},
|
||||
)
|
||||
logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
|
||||
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
|
||||
save_strategy: IntervalStrategy = field(
|
||||
default="steps",
|
||||
metadata={"help": "The checkpoint save strategy to use."},
|
||||
)
|
||||
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
|
||||
save_total_limit: Optional[int] = field(
|
||||
default=None,
|
||||
@@ -510,10 +522,19 @@ class TrainingArguments:
|
||||
self.output_dir = os.getenv("SM_OUTPUT_DATA_DIR")
|
||||
if self.disable_tqdm is None:
|
||||
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
|
||||
self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy)
|
||||
self.logging_strategy = LoggingStrategy(self.logging_strategy)
|
||||
|
||||
if isinstance(self.evaluation_strategy, EvaluationStrategy):
|
||||
warnings.warn(
|
||||
"using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `IntervalStrategy` instead",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy)
|
||||
self.logging_strategy = IntervalStrategy(self.logging_strategy)
|
||||
self.save_strategy = IntervalStrategy(self.save_strategy)
|
||||
|
||||
self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
|
||||
if self.do_eval is False and self.evaluation_strategy != EvaluationStrategy.NO:
|
||||
if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO:
|
||||
self.do_eval = True
|
||||
if self.eval_steps is None:
|
||||
self.eval_steps = self.logging_steps
|
||||
|
||||
@@ -58,7 +58,7 @@ class TFTrainingArguments(TrainingArguments):
|
||||
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
|
||||
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
|
||||
details.
|
||||
evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
|
||||
evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`):
|
||||
The evaluation strategy to adopt during training. Possible values are:
|
||||
|
||||
* :obj:`"no"`: No evaluation is done during training.
|
||||
@@ -102,7 +102,7 @@ class TFTrainingArguments(TrainingArguments):
|
||||
logging_dir (:obj:`str`, `optional`):
|
||||
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
|
||||
`runs/**CURRENT_DATETIME_HOSTNAME**`.
|
||||
logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.LoggingStrategy`, `optional`, defaults to :obj:`"steps"`):
|
||||
logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
|
||||
The logging strategy to adopt during training. Possible values are:
|
||||
|
||||
* :obj:`"no"`: No logging is done during training.
|
||||
@@ -113,8 +113,15 @@ class TFTrainingArguments(TrainingArguments):
|
||||
Whether to log and evaluate the first :obj:`global_step` or not.
|
||||
logging_steps (:obj:`int`, `optional`, defaults to 500):
|
||||
Number of update steps between two logs if :obj:`logging_strategy="steps"`.
|
||||
save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
|
||||
The checkpoint save strategy to adopt during training. Possible values are:
|
||||
|
||||
* :obj:`"no"`: No save is done during training.
|
||||
* :obj:`"epoch"`: Save is done at the end of each epoch.
|
||||
* :obj:`"steps"`: Save is done every :obj:`save_steps`.
|
||||
|
||||
save_steps (:obj:`int`, `optional`, defaults to 500):
|
||||
Number of updates steps before two checkpoint saves.
|
||||
Number of updates steps before two checkpoint saves if :obj:`save_strategy="steps"`.
|
||||
save_total_limit (:obj:`int`, `optional`):
|
||||
If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
|
||||
:obj:`output_dir`.
|
||||
|
||||
@@ -19,7 +19,7 @@ from typing import Optional
|
||||
import IPython.display as disp
|
||||
|
||||
from ..trainer_callback import TrainerCallback
|
||||
from ..trainer_utils import EvaluationStrategy
|
||||
from ..trainer_utils import IntervalStrategy
|
||||
|
||||
|
||||
def format_time(t):
|
||||
@@ -277,11 +277,11 @@ class NotebookProgressCallback(TrainerCallback):
|
||||
self._force_next_update = False
|
||||
|
||||
def on_train_begin(self, args, state, control, **kwargs):
|
||||
self.first_column = "Epoch" if args.evaluation_strategy == EvaluationStrategy.EPOCH else "Step"
|
||||
self.first_column = "Epoch" if args.evaluation_strategy == IntervalStrategy.EPOCH else "Step"
|
||||
self.training_loss = 0
|
||||
self.last_log = 0
|
||||
column_names = [self.first_column] + ["Training Loss"]
|
||||
if args.evaluation_strategy != EvaluationStrategy.NO:
|
||||
if args.evaluation_strategy != IntervalStrategy.NO:
|
||||
column_names.append("Validation Loss")
|
||||
self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
|
||||
|
||||
@@ -306,7 +306,7 @@ class NotebookProgressCallback(TrainerCallback):
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
# Only for when there is no evaluation
|
||||
if args.evaluation_strategy == EvaluationStrategy.NO and "loss" in logs:
|
||||
if args.evaluation_strategy == IntervalStrategy.NO and "loss" in logs:
|
||||
values = {"Training Loss": logs["loss"]}
|
||||
# First column is necessarily Step sine we're not in epoch eval strategy
|
||||
values["Step"] = state.global_step
|
||||
|
||||
Reference in New Issue
Block a user