Introduce logging_strategy training argument in TrainingArguments and TFTrainingArguments. (#9838)
This commit is contained in:
@@ -24,7 +24,7 @@ from typing import Dict, List, Optional, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
from .trainer_utils import EvaluationStrategy
|
from .trainer_utils import EvaluationStrategy, LoggingStrategy
|
||||||
from .training_args import TrainingArguments
|
from .training_args import TrainingArguments
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
@@ -403,7 +403,11 @@ class DefaultFlowCallback(TrainerCallback):
|
|||||||
# Log
|
# Log
|
||||||
if state.global_step == 1 and args.logging_first_step:
|
if state.global_step == 1 and args.logging_first_step:
|
||||||
control.should_log = True
|
control.should_log = True
|
||||||
if args.logging_steps > 0 and state.global_step % args.logging_steps == 0:
|
if (
|
||||||
|
args.logging_strategy == LoggingStrategy.STEPS
|
||||||
|
and args.logging_steps > 0
|
||||||
|
and state.global_step % args.logging_steps == 0
|
||||||
|
):
|
||||||
control.should_log = True
|
control.should_log = True
|
||||||
|
|
||||||
# Evaluate
|
# Evaluate
|
||||||
@@ -423,6 +427,11 @@ class DefaultFlowCallback(TrainerCallback):
|
|||||||
return control
|
return control
|
||||||
|
|
||||||
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
# Log
|
||||||
|
if args.logging_strategy == LoggingStrategy.EPOCH:
|
||||||
|
control.should_log = True
|
||||||
|
|
||||||
|
# Evaluate
|
||||||
if args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
if args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
||||||
control.should_evaluate = True
|
control.should_evaluate = True
|
||||||
if args.load_best_model_at_end:
|
if args.load_best_model_at_end:
|
||||||
|
|||||||
@@ -107,6 +107,12 @@ class EvaluationStrategy(ExplicitEnum):
|
|||||||
EPOCH = "epoch"
|
EPOCH = "epoch"
|
||||||
|
|
||||||
|
|
||||||
|
class LoggingStrategy(ExplicitEnum):
|
||||||
|
NO = "no"
|
||||||
|
STEPS = "steps"
|
||||||
|
EPOCH = "epoch"
|
||||||
|
|
||||||
|
|
||||||
class BestRun(NamedTuple):
|
class BestRun(NamedTuple):
|
||||||
"""
|
"""
|
||||||
The best run found by a hyperparameter search (see :class:`~transformers.Trainer.hyperparameter_search`).
|
The best run found by a hyperparameter search (see :class:`~transformers.Trainer.hyperparameter_search`).
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from .file_utils import (
|
|||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
torch_required,
|
torch_required,
|
||||||
)
|
)
|
||||||
from .trainer_utils import EvaluationStrategy, SchedulerType
|
from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -139,10 +139,17 @@ class TrainingArguments:
|
|||||||
logging_dir (:obj:`str`, `optional`):
|
logging_dir (:obj:`str`, `optional`):
|
||||||
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
|
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
|
||||||
`runs/**CURRENT_DATETIME_HOSTNAME**`.
|
`runs/**CURRENT_DATETIME_HOSTNAME**`.
|
||||||
|
logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.LoggingStrategy`, `optional`, defaults to :obj:`"steps"`):
|
||||||
|
The logging strategy to adopt during training. Possible values are:
|
||||||
|
|
||||||
|
* :obj:`"no"`: No logging is done during training.
|
||||||
|
* :obj:`"epoch"`: Logging is done at the end of each epoch.
|
||||||
|
* :obj:`"steps"`: Logging is done every :obj:`logging_steps`.
|
||||||
|
|
||||||
logging_first_step (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
logging_first_step (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether to log and evaluate the first :obj:`global_step` or not.
|
Whether to log and evaluate the first :obj:`global_step` or not.
|
||||||
logging_steps (:obj:`int`, `optional`, defaults to 500):
|
logging_steps (:obj:`int`, `optional`, defaults to 500):
|
||||||
Number of update steps between two logs.
|
Number of update steps between two logs if :obj:`logging_strategy="steps"`.
|
||||||
save_steps (:obj:`int`, `optional`, defaults to 500):
|
save_steps (:obj:`int`, `optional`, defaults to 500):
|
||||||
Number of update steps between two checkpoint saves.
|
Number of update steps between two checkpoint saves.
|
||||||
save_total_limit (:obj:`int`, `optional`):
|
save_total_limit (:obj:`int`, `optional`):
|
||||||
@@ -339,6 +346,10 @@ class TrainingArguments:
|
|||||||
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
||||||
|
|
||||||
logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
|
logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
|
||||||
|
logging_strategy: LoggingStrategy = field(
|
||||||
|
default="steps",
|
||||||
|
metadata={"help": "The logging strategy to use."},
|
||||||
|
)
|
||||||
logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
|
logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
|
||||||
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
|
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
|
||||||
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
|
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
|
||||||
@@ -482,6 +493,7 @@ class TrainingArguments:
|
|||||||
if self.disable_tqdm is None:
|
if self.disable_tqdm is None:
|
||||||
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
|
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
|
||||||
self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy)
|
self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy)
|
||||||
|
self.logging_strategy = LoggingStrategy(self.logging_strategy)
|
||||||
self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
|
self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
|
||||||
if self.do_eval is False and self.evaluation_strategy != EvaluationStrategy.NO:
|
if self.do_eval is False and self.evaluation_strategy != EvaluationStrategy.NO:
|
||||||
self.do_eval = True
|
self.do_eval = True
|
||||||
|
|||||||
@@ -102,10 +102,17 @@ class TFTrainingArguments(TrainingArguments):
|
|||||||
logging_dir (:obj:`str`, `optional`):
|
logging_dir (:obj:`str`, `optional`):
|
||||||
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
|
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
|
||||||
`runs/**CURRENT_DATETIME_HOSTNAME**`.
|
`runs/**CURRENT_DATETIME_HOSTNAME**`.
|
||||||
|
logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.LoggingStrategy`, `optional`, defaults to :obj:`"steps"`):
|
||||||
|
The logging strategy to adopt during training. Possible values are:
|
||||||
|
|
||||||
|
* :obj:`"no"`: No logging is done during training.
|
||||||
|
* :obj:`"epoch"`: Logging is done at the end of each epoch.
|
||||||
|
* :obj:`"steps"`: Logging is done every :obj:`logging_steps`.
|
||||||
|
|
||||||
logging_first_step (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
logging_first_step (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether to log and evaluate the first :obj:`global_step` or not.
|
Whether to log and evaluate the first :obj:`global_step` or not.
|
||||||
logging_steps (:obj:`int`, `optional`, defaults to 500):
|
logging_steps (:obj:`int`, `optional`, defaults to 500):
|
||||||
Number of update steps between two logs.
|
Number of update steps between two logs if :obj:`logging_strategy="steps"`.
|
||||||
save_steps (:obj:`int`, `optional`, defaults to 500):
|
save_steps (:obj:`int`, `optional`, defaults to 500):
|
||||||
Number of update steps between two checkpoint saves.
|
Number of update steps between two checkpoint saves.
|
||||||
save_total_limit (:obj:`int`, `optional`):
|
save_total_limit (:obj:`int`, `optional`):
|
||||||
|
|||||||
Reference in New Issue
Block a user