From 709c86b5a925f1efe650e24ee8b1f52bdc5a3acb Mon Sep 17 00:00:00 2001 From: Tanmay Garg Date: Fri, 19 Feb 2021 22:19:22 +0530 Subject: [PATCH] Introduce logging_strategy training argument (#10267) (#10267) Introduce logging_strategy training argument in TrainingArguments and TFTrainingArguments. (#9838) --- src/transformers/trainer_callback.py | 13 +++++++++++-- src/transformers/trainer_utils.py | 6 ++++++ src/transformers/training_args.py | 16 ++++++++++++++-- src/transformers/training_args_tf.py | 9 ++++++++- 4 files changed, 39 insertions(+), 5 deletions(-) diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 34027dc9e1..b16f70921d 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -24,7 +24,7 @@ from typing import Dict, List, Optional, Union import numpy as np from tqdm.auto import tqdm -from .trainer_utils import EvaluationStrategy +from .trainer_utils import EvaluationStrategy, LoggingStrategy from .training_args import TrainingArguments from .utils import logging @@ -403,7 +403,11 @@ class DefaultFlowCallback(TrainerCallback): # Log if state.global_step == 1 and args.logging_first_step: control.should_log = True - if args.logging_steps > 0 and state.global_step % args.logging_steps == 0: + if ( + args.logging_strategy == LoggingStrategy.STEPS + and args.logging_steps > 0 + and state.global_step % args.logging_steps == 0 + ): control.should_log = True # Evaluate @@ -423,6 +427,11 @@ class DefaultFlowCallback(TrainerCallback): return control def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + # Log + if args.logging_strategy == LoggingStrategy.EPOCH: + control.should_log = True + + # Evaluate if args.evaluation_strategy == EvaluationStrategy.EPOCH: control.should_evaluate = True if args.load_best_model_at_end: diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 760a92d5bd..5e0f1ae948 
100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -107,6 +107,12 @@ class EvaluationStrategy(ExplicitEnum): EPOCH = "epoch" +class LoggingStrategy(ExplicitEnum): + NO = "no" + STEPS = "steps" + EPOCH = "epoch" + + class BestRun(NamedTuple): """ The best run found by an hyperparameter search (see :class:`~transformers.Trainer.hyperparameter_search`). diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 05c05bb6a3..22504aa108 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -25,7 +25,7 @@ from .file_utils import ( is_torch_tpu_available, torch_required, ) -from .trainer_utils import EvaluationStrategy, SchedulerType +from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType from .utils import logging @@ -139,10 +139,17 @@ class TrainingArguments: logging_dir (:obj:`str`, `optional`): `TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to `runs/**CURRENT_DATETIME_HOSTNAME**`. + logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.LoggingStrategy`, `optional`, defaults to :obj:`"steps"`): + The logging strategy to adopt during training. Possible values are: + + * :obj:`"no"`: No logging is done during training. + * :obj:`"epoch"`: Logging is done at the end of each epoch. + * :obj:`"steps"`: Logging is done every :obj:`logging_steps`. + logging_first_step (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to log and evaluate the first :obj:`global_step` or not. logging_steps (:obj:`int`, `optional`, defaults to 500): - Number of update steps between two logs. + Number of update steps between two logs if :obj:`logging_strategy="steps"`. save_steps (:obj:`int`, `optional`, defaults to 500): Number of updates steps before two checkpoint saves. 
save_total_limit (:obj:`int`, `optional`): @@ -339,6 +346,10 @@ class TrainingArguments: warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."}) + logging_strategy: LoggingStrategy = field( + default="steps", + metadata={"help": "The logging strategy to use."}, + ) logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"}) logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."}) save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."}) @@ -482,6 +493,7 @@ class TrainingArguments: if self.disable_tqdm is None: self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy) + self.logging_strategy = LoggingStrategy(self.logging_strategy) self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type) if self.do_eval is False and self.evaluation_strategy != EvaluationStrategy.NO: self.do_eval = True diff --git a/src/transformers/training_args_tf.py b/src/transformers/training_args_tf.py index 8215a0122a..2b66d44487 100644 --- a/src/transformers/training_args_tf.py +++ b/src/transformers/training_args_tf.py @@ -102,10 +102,17 @@ class TFTrainingArguments(TrainingArguments): logging_dir (:obj:`str`, `optional`): `TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to `runs/**CURRENT_DATETIME_HOSTNAME**`. + logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.LoggingStrategy`, `optional`, defaults to :obj:`"steps"`): + The logging strategy to adopt during training. Possible values are: + + * :obj:`"no"`: No logging is done during training. + * :obj:`"epoch"`: Logging is done at the end of each epoch. + * :obj:`"steps"`: Logging is done every :obj:`logging_steps`. 
+ logging_first_step (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to log and evaluate the first :obj:`global_step` or not. logging_steps (:obj:`int`, `optional`, defaults to 500): - Number of update steps between two logs. + Number of update steps between two logs if :obj:`logging_strategy="steps"`. save_steps (:obj:`int`, `optional`, defaults to 500): Number of updates steps before two checkpoint saves. save_total_limit (:obj:`int`, `optional`):