Introduce save_strategy training argument (#10286)
* Introduce save_strategy training argument * deprecate EvaluationStrategy * collapse EvaluationStrategy and LoggingStrategy into a single IntervalStrategy enum * modify tests to use modified enum
This commit is contained in:
@@ -24,7 +24,7 @@ from typing import Dict, List, Optional, Union
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from .trainer_utils import EvaluationStrategy, LoggingStrategy
|
||||
from .trainer_utils import IntervalStrategy
|
||||
from .training_args import TrainingArguments
|
||||
from .utils import logging
|
||||
|
||||
@@ -404,20 +404,25 @@ class DefaultFlowCallback(TrainerCallback):
|
||||
if state.global_step == 1 and args.logging_first_step:
|
||||
control.should_log = True
|
||||
if (
|
||||
args.logging_strategy == LoggingStrategy.STEPS
|
||||
args.logging_strategy == IntervalStrategy.STEPS
|
||||
and args.logging_steps > 0
|
||||
and state.global_step % args.logging_steps == 0
|
||||
):
|
||||
control.should_log = True
|
||||
|
||||
# Evaluate
|
||||
if args.evaluation_strategy == EvaluationStrategy.STEPS and state.global_step % args.eval_steps == 0:
|
||||
if args.evaluation_strategy == IntervalStrategy.STEPS and state.global_step % args.eval_steps == 0:
|
||||
control.should_evaluate = True
|
||||
if args.load_best_model_at_end:
|
||||
control.should_save = True
|
||||
|
||||
# Save
|
||||
if not args.load_best_model_at_end and args.save_steps > 0 and state.global_step % args.save_steps == 0:
|
||||
if (
|
||||
not args.load_best_model_at_end
|
||||
and args.save_strategy == IntervalStrategy.STEPS
|
||||
and args.save_steps > 0
|
||||
and state.global_step % args.save_steps == 0
|
||||
):
|
||||
control.should_save = True
|
||||
|
||||
# End training
|
||||
@@ -428,14 +433,19 @@ class DefaultFlowCallback(TrainerCallback):
|
||||
|
||||
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
# Log
|
||||
if args.logging_strategy == LoggingStrategy.EPOCH:
|
||||
if args.logging_strategy == IntervalStrategy.EPOCH:
|
||||
control.should_log = True
|
||||
|
||||
# Evaluate
|
||||
if args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
||||
if args.evaluation_strategy == IntervalStrategy.EPOCH:
|
||||
control.should_evaluate = True
|
||||
if args.load_best_model_at_end:
|
||||
control.should_save = True
|
||||
|
||||
# Save
|
||||
if args.save_strategy == IntervalStrategy.EPOCH:
|
||||
control.should_save = True
|
||||
|
||||
return control
|
||||
|
||||
|
||||
@@ -531,8 +541,8 @@ class EarlyStoppingCallback(TrainerCallback):
|
||||
args.metric_for_best_model is not None
|
||||
), "EarlyStoppingCallback requires metric_for_best_model is defined"
|
||||
assert (
|
||||
args.evaluation_strategy != EvaluationStrategy.NO
|
||||
), "EarlyStoppingCallback requires EvaluationStrategy of steps or epoch"
|
||||
args.evaluation_strategy != IntervalStrategy.NO
|
||||
), "EarlyStoppingCallback requires IntervalStrategy of steps or epoch"
|
||||
|
||||
def on_evaluate(self, args, state, control, metrics, **kwargs):
|
||||
metric_to_check = args.metric_for_best_model
|
||||
|
||||
Reference in New Issue
Block a user