Introduce save_strategy training argument (#10286)

* Introduce save_strategy training argument

* deprecate EvaluationStrategy

* collapse EvaluationStrategy and LoggingStrategy into a single
  IntervalStrategy enum

* modify tests to use modified enum
This commit is contained in:
Tanmay Garg
2021-02-28 06:04:22 +05:30
committed by GitHub
parent aca6288ff4
commit 256482ac92
11 changed files with 81 additions and 46 deletions

View File

@@ -14,6 +14,7 @@
import json
import os
import warnings
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
@@ -25,7 +26,7 @@ from .file_utils import (
is_torch_tpu_available,
torch_required,
)
from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType, ShardedDDPOption
from .trainer_utils import EvaluationStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
from .utils import logging
@@ -84,7 +85,7 @@ class TrainingArguments:
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
details.
evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`):
The evaluation strategy to adopt during training. Possible values are:
* :obj:`"no"`: No evaluation is done during training.
@@ -139,7 +140,7 @@ class TrainingArguments:
logging_dir (:obj:`str`, `optional`):
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
`runs/**CURRENT_DATETIME_HOSTNAME**`.
logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.LoggingStrategy`, `optional`, defaults to :obj:`"steps"`):
logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
The logging strategy to adopt during training. Possible values are:
* :obj:`"no"`: No logging is done during training.
@@ -150,8 +151,15 @@ class TrainingArguments:
Whether to log and evaluate the first :obj:`global_step` or not.
logging_steps (:obj:`int`, `optional`, defaults to 500):
Number of update steps between two logs if :obj:`logging_strategy="steps"`.
save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
The checkpoint save strategy to adopt during training. Possible values are:
* :obj:`"no"`: No save is done during training.
* :obj:`"epoch"`: Save is done at the end of each epoch.
* :obj:`"steps"`: Save is done every :obj:`save_steps`.
save_steps (:obj:`int`, `optional`, defaults to 500):
Number of updates steps before two checkpoint saves.
Number of updates steps before two checkpoint saves if :obj:`save_strategy="steps"`.
save_total_limit (:obj:`int`, `optional`):
If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
:obj:`output_dir`.
@@ -215,8 +223,8 @@ class TrainingArguments:
.. note::
When set to :obj:`True`, the parameters :obj:`save_steps` will be ignored and the model will be saved
after each evaluation.
When set to :obj:`True`, the parameters :obj:`save_strategy` and :obj:`save_steps` will be ignored and
the model will be saved after each evaluation.
metric_for_best_model (:obj:`str`, `optional`):
Use in conjunction with :obj:`load_best_model_at_end` to specify the metric to use to compare two different
models. Must be the name of a metric returned by the evaluation with or without the prefix :obj:`"eval_"`.
@@ -297,7 +305,7 @@ class TrainingArguments:
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."})
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
evaluation_strategy: EvaluationStrategy = field(
evaluation_strategy: IntervalStrategy = field(
default="no",
metadata={"help": "The evaluation strategy to use."},
)
@@ -359,12 +367,16 @@ class TrainingArguments:
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
logging_strategy: LoggingStrategy = field(
logging_strategy: IntervalStrategy = field(
default="steps",
metadata={"help": "The logging strategy to use."},
)
logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
save_strategy: IntervalStrategy = field(
default="steps",
metadata={"help": "The checkpoint save strategy to use."},
)
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
save_total_limit: Optional[int] = field(
default=None,
@@ -510,10 +522,19 @@ class TrainingArguments:
self.output_dir = os.getenv("SM_OUTPUT_DATA_DIR")
if self.disable_tqdm is None:
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy)
self.logging_strategy = LoggingStrategy(self.logging_strategy)
if isinstance(self.evaluation_strategy, EvaluationStrategy):
warnings.warn(
"using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `IntervalStrategy` instead",
FutureWarning,
)
self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy)
self.logging_strategy = IntervalStrategy(self.logging_strategy)
self.save_strategy = IntervalStrategy(self.save_strategy)
self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
if self.do_eval is False and self.evaluation_strategy != EvaluationStrategy.NO:
if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO:
self.do_eval = True
if self.eval_steps is None:
self.eval_steps = self.logging_steps