From 89edf504bf2abdb7d53a3b7770fc77e7849e0ab8 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 22 Sep 2020 09:52:29 -0400 Subject: [PATCH] Add possibility to evaluate every epoch (#7302) * Add possibility to evaluate every epoch * Remove multitype arg * Remove needless import * Use a proper enum * Apply suggestions from @LysandreJik Co-authored-by: Lysandre Debut * One else and formatting Co-authored-by: Lysandre Debut --- src/transformers/trainer.py | 9 +++++- src/transformers/trainer_utils.py | 6 ++++ src/transformers/training_args.py | 52 ++++++++++++++++++++++++++----- 3 files changed, 58 insertions(+), 9 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a58ac0a427..de00161050 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -39,6 +39,7 @@ from .trainer_utils import ( PREFIX_CHECKPOINT_DIR, BestRun, EvalPrediction, + EvaluationStrategy, HPSearchBackend, PredictionOutput, TrainOutput, @@ -782,7 +783,10 @@ class Trainer: self.log(logs) - if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0: + if ( + self.args.evaluation_strategy == EvaluationStrategy.STEPS + and self.global_step % self.args.eval_steps == 0 + ): metrics = self.evaluate() self._report_to_hp_search(trial, epoch, metrics) @@ -820,6 +824,9 @@ class Trainer: torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) epoch_pbar.update(1) + if self.args.evaluation_strategy == EvaluationStrategy.EPOCH: + metrics = self.evaluate() + self._report_to_hp_search(trial, epoch, metrics) if self.args.max_steps > 0 and self.global_step >= self.args.max_steps: break epoch_pbar.close() diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index e273b00aa8..76e69ad559 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -60,6 +60,12 @@ class TrainOutput(NamedTuple): 
PREFIX_CHECKPOINT_DIR = "checkpoint" +class EvaluationStrategy(ExplicitEnum): + NO = "no" + STEPS = "steps" + EPOCH = "epoch" + + class BestRun(NamedTuple): """ The best run found by an hyperparameter search (see :class:`~transformers.Trainer.hyperparameter_search`). diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index c635595f2b..1000f09144 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1,10 +1,13 @@ import dataclasses import json import os +import warnings from dataclasses import dataclass, field +from enum import Enum from typing import Any, Dict, List, Optional, Tuple from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required +from .trainer_utils import EvaluationStrategy from .utils import logging @@ -50,8 +53,13 @@ class TrainingArguments: Whether to run evaluation on the dev set or not. do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to run predictions on the test set or not. - evaluate_during_training (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether to run evaluation during training at each logging step or not. + evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`): + The evaluation strategy to adopt during training. Possible values are: + + * :obj:`"no"`: No evaluation is done during training. + * :obj:`"steps"`: Evaluation is done (and logged) every :obj:`eval_steps`. + * :obj:`"epoch"`: Evaluation is done at the end of each epoch. + prediction_loss_only (:obj:`bool`, `optional`, defaults to `False`): When performing evaluation and predictions, only returns the loss.
per_device_train_batch_size (:obj:`int`, `optional`, defaults to 8): @@ -111,8 +119,9 @@ class TrainingArguments: dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size) or not. - eval_steps (:obj:`int`, `optional`, defaults to 1000): - Number of update steps between two evaluations. + eval_steps (:obj:`int`, `optional`): + Number of update steps between two evaluations if :obj:`evaluation_strategy="steps"`. Will default to the + same value as :obj:`logging_steps` if not set. past_index (:obj:`int`, `optional`, defaults to -1): Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc`XLNet <../model_doc/xlnet>` can make use of the past hidden states for their predictions. If this argument is set to a positive int, the @@ -153,7 +162,11 @@ class TrainingArguments: do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) evaluate_during_training: bool = field( - default=False, + default=None, + metadata={"help": "Run evaluation during training at each logging step."}, + ) + evaluation_strategy: EvaluationStrategy = field( + default="no", metadata={"help": "Run evaluation during training at each logging step."}, ) prediction_loss_only: bool = field( @@ -245,7 +258,7 @@ class TrainingArguments: dataloader_drop_last: bool = field( default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."} ) - eval_steps: int = field(default=1000, metadata={"help": "Run an evaluation every X steps."}) + eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."}) past_index: int = field( default=-1, @@ -269,6 +282,19 @@ class TrainingArguments: def __post_init__(self): if self.disable_tqdm is None: self.disable_tqdm = 
logger.getEffectiveLevel() > logging.WARN + if self.evaluate_during_training is not None: + self.evaluation_strategy = ( + EvaluationStrategy.STEPS if self.evaluate_during_training else EvaluationStrategy.NO + ) + warnings.warn( + "The `evaluate_during_training` argument is deprecated in favor of `evaluation_strategy` (which has more options)", + FutureWarning, + ) + else: + self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy) + + if self.eval_steps is None: + self.eval_steps = self.logging_steps @property def train_batch_size(self) -> int: @@ -347,17 +373,27 @@ class TrainingArguments: """ return self._setup_devices[1] + def to_dict(self): + """ + Serializes this instance while replacing `Enum` members with their values (for JSON serialization support). + """ + d = dataclasses.asdict(self) + for k, v in d.items(): + if isinstance(v, Enum): + d[k] = v.value + return d + def to_json_string(self): """ Serializes this instance to a JSON string. """ - return json.dumps(dataclasses.asdict(self), indent=2) + return json.dumps(self.to_dict(), indent=2) def to_sanitized_dict(self) -> Dict[str, Any]: """ Sanitized serialization to use with TensorBoard’s hparams """ - d = dataclasses.asdict(self) + d = self.to_dict() d = {**d, **{"train_batch_size": self.train_batch_size, "eval_batch_size": self.eval_batch_size}} valid_types = [bool, int, float, str]