Introduce save_strategy training argument (#10286)

* Introduce save_strategy training argument

* deprecate EvaluationStrategy

* collapse EvaluationStrategy and LoggingStrategy into a single
  IntervalStrategy enum

* modify tests to use modified enum
This commit is contained in:
Tanmay Garg
2021-02-28 06:04:22 +05:30
committed by GitHub
parent aca6288ff4
commit 256482ac92
11 changed files with 81 additions and 46 deletions

View File

@@ -18,7 +18,7 @@ import unittest
from transformers import (
DefaultFlowCallback,
EvaluationStrategy,
IntervalStrategy,
PrinterCallback,
ProgressCallback,
Trainer,
@@ -129,15 +129,12 @@ class TrainerCallbackTest(unittest.TestCase):
expected_events += ["on_step_begin", "on_step_end"]
if step % trainer.args.logging_steps == 0:
expected_events.append("on_log")
if (
trainer.args.evaluation_strategy == EvaluationStrategy.STEPS
and step % trainer.args.eval_steps == 0
):
if trainer.args.evaluation_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
expected_events += evaluation_events.copy()
if step % trainer.args.save_steps == 0:
expected_events.append("on_save")
expected_events.append("on_epoch_end")
if trainer.args.evaluation_strategy == EvaluationStrategy.EPOCH:
if trainer.args.evaluation_strategy == IntervalStrategy.EPOCH:
expected_events += evaluation_events.copy()
expected_events += ["on_log", "on_train_end"]
return expected_events