Introduce logging_strategy training argument in TrainingArguments and TFTrainingArguments. (#9838)
This commit is contained in:
@@ -24,7 +24,7 @@ from typing import Dict, List, Optional, Union
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from .trainer_utils import EvaluationStrategy
|
||||
from .trainer_utils import EvaluationStrategy, LoggingStrategy
|
||||
from .training_args import TrainingArguments
|
||||
from .utils import logging
|
||||
|
||||
@@ -403,7 +403,11 @@ class DefaultFlowCallback(TrainerCallback):
|
||||
# Log
|
||||
if state.global_step == 1 and args.logging_first_step:
|
||||
control.should_log = True
|
||||
if args.logging_steps > 0 and state.global_step % args.logging_steps == 0:
|
||||
if (
|
||||
args.logging_strategy == LoggingStrategy.STEPS
|
||||
and args.logging_steps > 0
|
||||
and state.global_step % args.logging_steps == 0
|
||||
):
|
||||
control.should_log = True
|
||||
|
||||
# Evaluate
|
||||
@@ -423,6 +427,11 @@ class DefaultFlowCallback(TrainerCallback):
|
||||
return control
|
||||
|
||||
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
# Log
|
||||
if args.logging_strategy == LoggingStrategy.EPOCH:
|
||||
control.should_log = True
|
||||
|
||||
# Evaluate
|
||||
if args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
||||
control.should_evaluate = True
|
||||
if args.load_best_model_at_end:
|
||||
|
||||
Reference in New Issue
Block a user