Add possibility to evaluate every epoch (#7302)
* Add possibility to evaluate every epoch * Remove multitype arg * Remove needless import * Use a proper enum * Apply suggestions from @LysandreJik Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * One else and formatting Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -39,6 +39,7 @@ from .trainer_utils import (
|
|||||||
PREFIX_CHECKPOINT_DIR,
|
PREFIX_CHECKPOINT_DIR,
|
||||||
BestRun,
|
BestRun,
|
||||||
EvalPrediction,
|
EvalPrediction,
|
||||||
|
EvaluationStrategy,
|
||||||
HPSearchBackend,
|
HPSearchBackend,
|
||||||
PredictionOutput,
|
PredictionOutput,
|
||||||
TrainOutput,
|
TrainOutput,
|
||||||
@@ -782,7 +783,10 @@ class Trainer:
|
|||||||
|
|
||||||
self.log(logs)
|
self.log(logs)
|
||||||
|
|
||||||
if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
|
if (
|
||||||
|
self.args.evaluation_strategy == EvaluationStrategy.STEPS
|
||||||
|
and self.global_step % self.args.eval_steps == 0
|
||||||
|
):
|
||||||
metrics = self.evaluate()
|
metrics = self.evaluate()
|
||||||
self._report_to_hp_search(trial, epoch, metrics)
|
self._report_to_hp_search(trial, epoch, metrics)
|
||||||
|
|
||||||
@@ -820,6 +824,9 @@ class Trainer:
|
|||||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||||
|
|
||||||
epoch_pbar.update(1)
|
epoch_pbar.update(1)
|
||||||
|
if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
||||||
|
metrics = self.evaluate()
|
||||||
|
self._report_to_hp_search(trial, epoch, metrics)
|
||||||
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
|
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
|
||||||
break
|
break
|
||||||
epoch_pbar.close()
|
epoch_pbar.close()
|
||||||
|
|||||||
@@ -60,6 +60,12 @@ class TrainOutput(NamedTuple):
|
|||||||
PREFIX_CHECKPOINT_DIR = "checkpoint"
|
PREFIX_CHECKPOINT_DIR = "checkpoint"
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationStrategy(ExplicitEnum):
|
||||||
|
NO = "no"
|
||||||
|
STEPS = "steps"
|
||||||
|
EPOCH = "epoch"
|
||||||
|
|
||||||
|
|
||||||
class BestRun(NamedTuple):
|
class BestRun(NamedTuple):
|
||||||
"""
|
"""
|
||||||
The best run found by a hyperparameter search (see :meth:`~transformers.Trainer.hyperparameter_search`).
|
The best run found by a hyperparameter search (see :meth:`~transformers.Trainer.hyperparameter_search`).
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
|
from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
|
||||||
|
from .trainer_utils import EvaluationStrategy
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -50,8 +53,13 @@ class TrainingArguments:
|
|||||||
Whether to run evaluation on the dev set or not.
|
Whether to run evaluation on the dev set or not.
|
||||||
do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether to run predictions on the test set or not.
|
Whether to run predictions on the test set or not.
|
||||||
evaluate_during_training (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
|
||||||
Whether to run evaluation during training at each logging step or not.
|
The evaluation strategy to adopt during training. Possible values are:
|
||||||
|
|
||||||
|
* :obj:`"no"`: No evaluation is done during training.
|
||||||
|
* :obj:`"steps"`: Evaluation is done (and logged) every :obj:`eval_steps`.
|
||||||
|
* :obj:`"epoch"`: Evaluation is done at the end of each epoch.
|
||||||
|
|
||||||
prediction_loss_only (:obj:`bool`, `optional`, defaults to `False`):
|
prediction_loss_only (:obj:`bool`, `optional`, defaults to `False`):
|
||||||
When performing evaluation and predictions, only returns the loss.
|
When performing evaluation and predictions, only returns the loss.
|
||||||
per_device_train_batch_size (:obj:`int`, `optional`, defaults to 8):
|
per_device_train_batch_size (:obj:`int`, `optional`, defaults to 8):
|
||||||
@@ -111,8 +119,9 @@ class TrainingArguments:
|
|||||||
dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
|
Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
|
||||||
or not.
|
or not.
|
||||||
eval_steps (:obj:`int`, `optional`, defaults to 1000):
|
eval_steps (:obj:`int`, `optional`):
|
||||||
Number of update steps between two evaluations.
|
Number of update steps between two evaluations if :obj:`evaluation_strategy="steps"`. Will default to the
|
||||||
|
same value as :obj:`logging_steps` if not set.
|
||||||
past_index (:obj:`int`, `optional`, defaults to -1):
|
past_index (:obj:`int`, `optional`, defaults to -1):
|
||||||
Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc:`XLNet <../model_doc/xlnet>` can
|
Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc:`XLNet <../model_doc/xlnet>` can
|
||||||
make use of the past hidden states for their predictions. If this argument is set to a positive int, the
|
make use of the past hidden states for their predictions. If this argument is set to a positive int, the
|
||||||
@@ -153,7 +162,11 @@ class TrainingArguments:
|
|||||||
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
|
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
|
||||||
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
|
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
|
||||||
evaluate_during_training: bool = field(
|
evaluate_during_training: bool = field(
|
||||||
default=False,
|
default=None,
|
||||||
|
metadata={"help": "Run evaluation during training at each logging step."},
|
||||||
|
)
|
||||||
|
evaluation_strategy: EvaluationStrategy = field(
|
||||||
|
default="no",
|
||||||
metadata={"help": "Run evaluation during training at each logging step."},
|
metadata={"help": "Run evaluation during training at each logging step."},
|
||||||
)
|
)
|
||||||
prediction_loss_only: bool = field(
|
prediction_loss_only: bool = field(
|
||||||
@@ -245,7 +258,7 @@ class TrainingArguments:
|
|||||||
dataloader_drop_last: bool = field(
|
dataloader_drop_last: bool = field(
|
||||||
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
|
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
|
||||||
)
|
)
|
||||||
eval_steps: int = field(default=1000, metadata={"help": "Run an evaluation every X steps."})
|
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
|
||||||
|
|
||||||
past_index: int = field(
|
past_index: int = field(
|
||||||
default=-1,
|
default=-1,
|
||||||
@@ -269,6 +282,19 @@ class TrainingArguments:
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.disable_tqdm is None:
|
if self.disable_tqdm is None:
|
||||||
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
|
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
|
||||||
|
if self.evaluate_during_training is not None:
|
||||||
|
self.evaluation_strategy = (
|
||||||
|
EvaluationStrategy.STEPS if self.evaluate_during_training else EvaluationStrategy.NO
|
||||||
|
)
|
||||||
|
warnings.warn(
|
||||||
|
"The `evaluate_during_training` argument is deprecated in favor of `evaluation_strategy` (which has more options)",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy)
|
||||||
|
|
||||||
|
if self.eval_steps is None:
|
||||||
|
self.eval_steps = self.logging_steps
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def train_batch_size(self) -> int:
|
def train_batch_size(self) -> int:
|
||||||
@@ -347,17 +373,27 @@ class TrainingArguments:
|
|||||||
"""
|
"""
|
||||||
return self._setup_devices[1]
|
return self._setup_devices[1]
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
"""
|
||||||
|
Serializes this instance while replacing `Enum` members by their values (for JSON serialization support).
|
||||||
|
"""
|
||||||
|
d = dataclasses.asdict(self)
|
||||||
|
for k, v in d.items():
|
||||||
|
if isinstance(v, Enum):
|
||||||
|
d[k] = v.value
|
||||||
|
return d
|
||||||
|
|
||||||
def to_json_string(self):
|
def to_json_string(self):
|
||||||
"""
|
"""
|
||||||
Serializes this instance to a JSON string.
|
Serializes this instance to a JSON string.
|
||||||
"""
|
"""
|
||||||
return json.dumps(dataclasses.asdict(self), indent=2)
|
return json.dumps(self.to_dict(), indent=2)
|
||||||
|
|
||||||
def to_sanitized_dict(self) -> Dict[str, Any]:
|
def to_sanitized_dict(self) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Sanitized serialization to use with TensorBoard’s hparams
|
Sanitized serialization to use with TensorBoard’s hparams
|
||||||
"""
|
"""
|
||||||
d = dataclasses.asdict(self)
|
d = self.to_dict()
|
||||||
d = {**d, **{"train_batch_size": self.train_batch_size, "eval_batch_size": self.eval_batch_size}}
|
d = {**d, **{"train_batch_size": self.train_batch_size, "eval_batch_size": self.eval_batch_size}}
|
||||||
|
|
||||||
valid_types = [bool, int, float, str]
|
valid_types = [bool, int, float, str]
|
||||||
|
|||||||
Reference in New Issue
Block a user