Introduce save_strategy training argument (#10286)

* Introduce save_strategy training argument

* deprecate EvaluationStrategy

* collapse EvaluationStrategy and LoggingStrategy into a single
  IntervalStrategy enum

* modify tests to use modified enum
This commit is contained in:
Tanmay Garg
2021-02-28 06:04:22 +05:30
committed by GitHub
parent aca6288ff4
commit 256482ac92
11 changed files with 81 additions and 46 deletions

View File

@@ -24,7 +24,7 @@ from typing import Dict, List, Optional, Union
import numpy as np
from tqdm.auto import tqdm
from .trainer_utils import EvaluationStrategy, LoggingStrategy
from .trainer_utils import IntervalStrategy
from .training_args import TrainingArguments
from .utils import logging
@@ -404,20 +404,25 @@ class DefaultFlowCallback(TrainerCallback):
if state.global_step == 1 and args.logging_first_step:
control.should_log = True
if (
args.logging_strategy == LoggingStrategy.STEPS
args.logging_strategy == IntervalStrategy.STEPS
and args.logging_steps > 0
and state.global_step % args.logging_steps == 0
):
control.should_log = True
# Evaluate
if args.evaluation_strategy == EvaluationStrategy.STEPS and state.global_step % args.eval_steps == 0:
if args.evaluation_strategy == IntervalStrategy.STEPS and state.global_step % args.eval_steps == 0:
control.should_evaluate = True
if args.load_best_model_at_end:
control.should_save = True
# Save
if not args.load_best_model_at_end and args.save_steps > 0 and state.global_step % args.save_steps == 0:
if (
not args.load_best_model_at_end
and args.save_strategy == IntervalStrategy.STEPS
and args.save_steps > 0
and state.global_step % args.save_steps == 0
):
control.should_save = True
# End training
@@ -428,14 +433,19 @@ class DefaultFlowCallback(TrainerCallback):
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
# Log
if args.logging_strategy == LoggingStrategy.EPOCH:
if args.logging_strategy == IntervalStrategy.EPOCH:
control.should_log = True
# Evaluate
if args.evaluation_strategy == EvaluationStrategy.EPOCH:
if args.evaluation_strategy == IntervalStrategy.EPOCH:
control.should_evaluate = True
if args.load_best_model_at_end:
control.should_save = True
# Save
if args.save_strategy == IntervalStrategy.EPOCH:
control.should_save = True
return control
@@ -531,8 +541,8 @@ class EarlyStoppingCallback(TrainerCallback):
args.metric_for_best_model is not None
), "EarlyStoppingCallback requires metric_for_best_model is defined"
assert (
args.evaluation_strategy != EvaluationStrategy.NO
), "EarlyStoppingCallback requires EvaluationStrategy of steps or epoch"
args.evaluation_strategy != IntervalStrategy.NO
), "EarlyStoppingCallback requires IntervalStrategy of steps or epoch"
def on_evaluate(self, args, state, control, metrics, **kwargs):
metric_to_check = args.metric_for_best_model