Introduce save_strategy training argument (#10286)

* Introduce save_strategy training argument

* deprecate EvaluationStrategy

* collapse EvaluationStrategy and LoggingStrategy into a single
  IntervalStrategy enum

* modify tests to use modified enum
This commit is contained in:
Tanmay Garg
2021-02-28 06:04:22 +05:30
committed by GitHub
parent aca6288ff4
commit 256482ac92
11 changed files with 81 additions and 46 deletions

View File

@@ -21,7 +21,7 @@ import unittest
import numpy as np
from transformers import AutoTokenizer, EvaluationStrategy, PretrainedConfig, TrainingArguments, is_torch_available
from transformers import AutoTokenizer, IntervalStrategy, PretrainedConfig, TrainingArguments, is_torch_available
from transformers.file_utils import WEIGHTS_NAME
from transformers.testing_utils import (
get_tests_dir,
@@ -852,7 +852,7 @@ class TrainerIntegrationTest(unittest.TestCase):
gradient_accumulation_steps=1,
per_device_train_batch_size=16,
load_best_model_at_end=True,
evaluation_strategy=EvaluationStrategy.EPOCH,
evaluation_strategy=IntervalStrategy.EPOCH,
compute_metrics=AlmostAccuracy(),
metric_for_best_model="accuracy",
)
@@ -867,7 +867,7 @@ class TrainerIntegrationTest(unittest.TestCase):
num_train_epochs=20,
gradient_accumulation_steps=1,
per_device_train_batch_size=16,
evaluation_strategy=EvaluationStrategy.EPOCH,
evaluation_strategy=IntervalStrategy.EPOCH,
compute_metrics=AlmostAccuracy(),
metric_for_best_model="accuracy",
)
@@ -1013,7 +1013,7 @@ class TrainerHyperParameterOptunaIntegrationTest(unittest.TestCase):
output_dir=tmp_dir,
learning_rate=0.1,
logging_steps=1,
evaluation_strategy=EvaluationStrategy.EPOCH,
evaluation_strategy=IntervalStrategy.EPOCH,
num_train_epochs=4,
disable_tqdm=True,
load_best_model_at_end=True,
@@ -1057,7 +1057,7 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
output_dir=tmp_dir,
learning_rate=0.1,
logging_steps=1,
evaluation_strategy=EvaluationStrategy.EPOCH,
evaluation_strategy=IntervalStrategy.EPOCH,
num_train_epochs=4,
disable_tqdm=True,
load_best_model_at_end=True,