Introduce save_strategy training argument (#10286)
* Introduce save_strategy training argument * deprecate EvaluationStrategy * collapse EvaluationStrategy and LoggingStrategy into a single IntervalStrategy enum * modify tests to use modified enum
This commit is contained in:
@@ -18,7 +18,7 @@ import unittest
|
||||
|
||||
from transformers import (
|
||||
DefaultFlowCallback,
|
||||
EvaluationStrategy,
|
||||
IntervalStrategy,
|
||||
PrinterCallback,
|
||||
ProgressCallback,
|
||||
Trainer,
|
||||
@@ -129,15 +129,12 @@ class TrainerCallbackTest(unittest.TestCase):
|
||||
expected_events += ["on_step_begin", "on_step_end"]
|
||||
if step % trainer.args.logging_steps == 0:
|
||||
expected_events.append("on_log")
|
||||
if (
|
||||
trainer.args.evaluation_strategy == EvaluationStrategy.STEPS
|
||||
and step % trainer.args.eval_steps == 0
|
||||
):
|
||||
if trainer.args.evaluation_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
|
||||
expected_events += evaluation_events.copy()
|
||||
if step % trainer.args.save_steps == 0:
|
||||
expected_events.append("on_save")
|
||||
expected_events.append("on_epoch_end")
|
||||
if trainer.args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
||||
if trainer.args.evaluation_strategy == IntervalStrategy.EPOCH:
|
||||
expected_events += evaluation_events.copy()
|
||||
expected_events += ["on_log", "on_train_end"]
|
||||
return expected_events
|
||||
|
||||
Reference in New Issue
Block a user