Introduce save_strategy training argument (#10286)
* Introduce save_strategy training argument * deprecate EvaluationStrategy * collapse EvaluationStrategy and LoggingStrategy into a single IntervalStrategy enum * modify tests to use modified enum
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -25,7 +26,7 @@ from .file_utils import (
|
||||
is_torch_tpu_available,
|
||||
torch_required,
|
||||
)
|
||||
from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType, ShardedDDPOption
|
||||
from .trainer_utils import EvaluationStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
|
||||
from .utils import logging
|
||||
|
||||
|
||||
@@ -84,7 +85,7 @@ class TrainingArguments:
|
||||
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
|
||||
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
|
||||
details.
|
||||
evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
|
||||
evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`):
|
||||
The evaluation strategy to adopt during training. Possible values are:
|
||||
|
||||
* :obj:`"no"`: No evaluation is done during training.
|
||||
@@ -139,7 +140,7 @@ class TrainingArguments:
|
||||
logging_dir (:obj:`str`, `optional`):
|
||||
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
|
||||
`runs/**CURRENT_DATETIME_HOSTNAME**`.
|
||||
logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.LoggingStrategy`, `optional`, defaults to :obj:`"steps"`):
|
||||
logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
|
||||
The logging strategy to adopt during training. Possible values are:
|
||||
|
||||
* :obj:`"no"`: No logging is done during training.
|
||||
@@ -150,8 +151,15 @@ class TrainingArguments:
|
||||
Whether to log and evaluate the first :obj:`global_step` or not.
|
||||
logging_steps (:obj:`int`, `optional`, defaults to 500):
|
||||
Number of update steps between two logs if :obj:`logging_strategy="steps"`.
|
||||
save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
|
||||
The checkpoint save strategy to adopt during training. Possible values are:
|
||||
|
||||
* :obj:`"no"`: No save is done during training.
|
||||
* :obj:`"epoch"`: Save is done at the end of each epoch.
|
||||
* :obj:`"steps"`: Save is done every :obj:`save_steps`.
|
||||
|
||||
save_steps (:obj:`int`, `optional`, defaults to 500):
|
||||
Number of updates steps before two checkpoint saves.
|
||||
Number of updates steps before two checkpoint saves if :obj:`save_strategy="steps"`.
|
||||
save_total_limit (:obj:`int`, `optional`):
|
||||
If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
|
||||
:obj:`output_dir`.
|
||||
@@ -215,8 +223,8 @@ class TrainingArguments:
|
||||
|
||||
.. note::
|
||||
|
||||
When set to :obj:`True`, the parameters :obj:`save_steps` will be ignored and the model will be saved
|
||||
after each evaluation.
|
||||
When set to :obj:`True`, the parameters :obj:`save_strategy` and :obj:`save_steps` will be ignored and
|
||||
the model will be saved after each evaluation.
|
||||
metric_for_best_model (:obj:`str`, `optional`):
|
||||
Use in conjunction with :obj:`load_best_model_at_end` to specify the metric to use to compare two different
|
||||
models. Must be the name of a metric returned by the evaluation with or without the prefix :obj:`"eval_"`.
|
||||
@@ -297,7 +305,7 @@ class TrainingArguments:
|
||||
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
||||
do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."})
|
||||
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
|
||||
evaluation_strategy: EvaluationStrategy = field(
|
||||
evaluation_strategy: IntervalStrategy = field(
|
||||
default="no",
|
||||
metadata={"help": "The evaluation strategy to use."},
|
||||
)
|
||||
@@ -359,12 +367,16 @@ class TrainingArguments:
|
||||
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
||||
|
||||
logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
|
||||
logging_strategy: LoggingStrategy = field(
|
||||
logging_strategy: IntervalStrategy = field(
|
||||
default="steps",
|
||||
metadata={"help": "The logging strategy to use."},
|
||||
)
|
||||
logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
|
||||
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
|
||||
save_strategy: IntervalStrategy = field(
|
||||
default="steps",
|
||||
metadata={"help": "The checkpoint save strategy to use."},
|
||||
)
|
||||
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
|
||||
save_total_limit: Optional[int] = field(
|
||||
default=None,
|
||||
@@ -510,10 +522,19 @@ class TrainingArguments:
|
||||
self.output_dir = os.getenv("SM_OUTPUT_DATA_DIR")
|
||||
if self.disable_tqdm is None:
|
||||
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
|
||||
self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy)
|
||||
self.logging_strategy = LoggingStrategy(self.logging_strategy)
|
||||
|
||||
if isinstance(self.evaluation_strategy, EvaluationStrategy):
|
||||
warnings.warn(
|
||||
"using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `IntervalStrategy` instead",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy)
|
||||
self.logging_strategy = IntervalStrategy(self.logging_strategy)
|
||||
self.save_strategy = IntervalStrategy(self.save_strategy)
|
||||
|
||||
self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
|
||||
if self.do_eval is False and self.evaluation_strategy != EvaluationStrategy.NO:
|
||||
if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO:
|
||||
self.do_eval = True
|
||||
if self.eval_steps is None:
|
||||
self.eval_steps = self.logging_steps
|
||||
|
||||
Reference in New Issue
Block a user