Add possibility to evaluate every epoch (#7302)

* Add possibility to evaluate every epoch

* Remove multitype arg

* Remove needless import

* Use a proper enum

* Apply suggestions from @LysandreJik

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* One else and formatting

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Sylvain Gugger
2020-09-22 09:52:29 -04:00
committed by GitHub
parent 21ca148090
commit 89edf504bf
3 changed files with 58 additions and 9 deletions

View File

@@ -39,6 +39,7 @@ from .trainer_utils import (
PREFIX_CHECKPOINT_DIR, PREFIX_CHECKPOINT_DIR,
BestRun, BestRun,
EvalPrediction, EvalPrediction,
EvaluationStrategy,
HPSearchBackend, HPSearchBackend,
PredictionOutput, PredictionOutput,
TrainOutput, TrainOutput,
@@ -782,7 +783,10 @@ class Trainer:
self.log(logs) self.log(logs)
if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0: if (
self.args.evaluation_strategy == EvaluationStrategy.STEPS
and self.global_step % self.args.eval_steps == 0
):
metrics = self.evaluate() metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics) self._report_to_hp_search(trial, epoch, metrics)
@@ -820,6 +824,9 @@ class Trainer:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
epoch_pbar.update(1) epoch_pbar.update(1)
if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps: if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
break break
epoch_pbar.close() epoch_pbar.close()

View File

@@ -60,6 +60,12 @@ class TrainOutput(NamedTuple):
PREFIX_CHECKPOINT_DIR = "checkpoint" PREFIX_CHECKPOINT_DIR = "checkpoint"
# Choices for the `evaluation_strategy` training argument: when evaluation is run during training.
class EvaluationStrategy(ExplicitEnum):
NO = "no"  # no evaluation is done during training
STEPS = "steps"  # evaluation is done every `eval_steps` update steps
EPOCH = "epoch"  # evaluation is done at the end of each epoch
class BestRun(NamedTuple): class BestRun(NamedTuple):
""" """
The best run found by an hyperparameter search (see :class:`~transformers.Trainer.hyperparameter_search`). The best run found by an hyperparameter search (see :class:`~transformers.Trainer.hyperparameter_search`).

View File

@@ -1,10 +1,13 @@
import dataclasses import dataclasses
import json import json
import os import os
import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
from .trainer_utils import EvaluationStrategy
from .utils import logging from .utils import logging
@@ -50,8 +53,13 @@ class TrainingArguments:
Whether to run evaluation on the dev set or not. Whether to run evaluation on the dev set or not.
do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`): do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to run predictions on the test set or not. Whether to run predictions on the test set or not.
evaluate_during_training (:obj:`bool`, `optional`, defaults to :obj:`False`): evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
Whether to run evaluation during training at each logging step or not. The evaluation strategy to adopt during training. Possible values are:
* :obj:`"no"`: No evaluation is done during training.
* :obj:`"steps"`: Evaluation is done (and logged) every :obj:`eval_steps`.
* :obj:`"epoch"`: Evaluation is done at the end of each epoch.
prediction_loss_only (:obj:`bool`, `optional`, defaults to `False`): prediction_loss_only (:obj:`bool`, `optional`, defaults to `False`):
When performing evaluation and predictions, only returns the loss. When performing evaluation and predictions, only returns the loss.
per_device_train_batch_size (:obj:`int`, `optional`, defaults to 8): per_device_train_batch_size (:obj:`int`, `optional`, defaults to 8):
@@ -111,8 +119,9 @@ class TrainingArguments:
dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`): dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size) Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
or not. or not.
eval_steps (:obj:`int`, `optional`, defaults to 1000): eval_steps (:obj:`int`, `optional`):
Number of update steps between two evaluations. Number of update steps between two evaluations if :obj:`evaluation_strategy="steps"`. Will default to the
same value as :obj:`logging_steps` if not set.
past_index (:obj:`int`, `optional`, defaults to -1): past_index (:obj:`int`, `optional`, defaults to -1):
Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc:`XLNet <../model_doc/xlnet>` can Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc:`XLNet <../model_doc/xlnet>` can
make use of the past hidden states for their predictions. If this argument is set to a positive int, the make use of the past hidden states for their predictions. If this argument is set to a positive int, the
@@ -153,7 +162,11 @@ class TrainingArguments:
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
evaluate_during_training: bool = field( evaluate_during_training: bool = field(
default=False, default=None,
metadata={"help": "Run evaluation during training at each logging step."},
)
evaluation_strategy: EvaluationStrategy = field(
default="no",
metadata={"help": "Run evaluation during training at each logging step."}, metadata={"help": "Run evaluation during training at each logging step."},
) )
prediction_loss_only: bool = field( prediction_loss_only: bool = field(
@@ -245,7 +258,7 @@ class TrainingArguments:
dataloader_drop_last: bool = field( dataloader_drop_last: bool = field(
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."} default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
) )
eval_steps: int = field(default=1000, metadata={"help": "Run an evaluation every X steps."}) eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
past_index: int = field( past_index: int = field(
default=-1, default=-1,
@@ -269,6 +282,19 @@ class TrainingArguments:
def __post_init__(self): def __post_init__(self):
if self.disable_tqdm is None: if self.disable_tqdm is None:
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
if self.evaluate_during_training is not None:
self.evaluation_strategy = (
EvaluationStrategy.STEPS if self.evaluate_during_training else EvaluationStrategy.NO
)
warnings.warn(
"The `evaluate_during_training` argument is deprecated in favor of `evaluation_strategy` (which has more options)",
FutureWarning,
)
else:
self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy)
if self.eval_steps is None:
self.eval_steps = self.logging_steps
@property @property
def train_batch_size(self) -> int: def train_batch_size(self) -> int:
@@ -347,17 +373,27 @@ class TrainingArguments:
""" """
return self._setup_devices[1] return self._setup_devices[1]
def to_dict(self):
    """
    Serializes this instance to a dict, replacing every top-level `Enum` field value by its underlying
    `.value` (for JSON serialization support).
    """
    # `dataclasses.asdict` returns a fresh dict, so building a new mapping from it is equivalent to
    # patching the entries in place.
    return {
        name: (value.value if isinstance(value, Enum) else value)
        for name, value in dataclasses.asdict(self).items()
    }
def to_json_string(self): def to_json_string(self):
""" """
Serializes this instance to a JSON string. Serializes this instance to a JSON string.
""" """
return json.dumps(dataclasses.asdict(self), indent=2) return json.dumps(self.to_dict(), indent=2)
def to_sanitized_dict(self) -> Dict[str, Any]: def to_sanitized_dict(self) -> Dict[str, Any]:
""" """
Sanitized serialization to use with TensorBoards hparams Sanitized serialization to use with TensorBoards hparams
""" """
d = dataclasses.asdict(self) d = self.to_dict()
d = {**d, **{"train_batch_size": self.train_batch_size, "eval_batch_size": self.eval_batch_size}} d = {**d, **{"train_batch_size": self.train_batch_size, "eval_batch_size": self.eval_batch_size}}
valid_types = [bool, int, float, str] valid_types = [bool, int, float, str]