[DeepSpeed] improve checkpoint loading code plus tests (#10760)
* deepspeed checkpoint loading code plus tests * style * style
This commit is contained in:
@@ -24,6 +24,7 @@ import numpy as np
|
||||
from transformers import AutoTokenizer, IntervalStrategy, PretrainedConfig, TrainingArguments, is_torch_available
|
||||
from transformers.file_utils import WEIGHTS_NAME
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
get_tests_dir,
|
||||
require_datasets,
|
||||
require_optuna,
|
||||
@@ -235,28 +236,7 @@ if is_torch_available():
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class TrainerIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
args = TrainingArguments(".")
|
||||
self.n_epochs = args.num_train_epochs
|
||||
self.batch_size = args.train_batch_size
|
||||
trainer = get_regression_trainer(learning_rate=0.1)
|
||||
trainer.train()
|
||||
self.default_trained_model = (trainer.model.a, trainer.model.b)
|
||||
|
||||
trainer = get_regression_trainer(learning_rate=0.1, seed=314)
|
||||
trainer.train()
|
||||
self.alternate_trained_model = (trainer.model.a, trainer.model.b)
|
||||
|
||||
def check_trained_model(self, model, alternate_seed=False):
|
||||
# Checks a training seeded with learning_rate = 0.1
|
||||
(a, b) = self.alternate_trained_model if alternate_seed else self.default_trained_model
|
||||
self.assertTrue(torch.allclose(model.a, a))
|
||||
self.assertTrue(torch.allclose(model.b, b))
|
||||
|
||||
class TrainerIntegrationCommon:
|
||||
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True):
|
||||
file_list = [WEIGHTS_NAME, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
|
||||
if is_pretrained:
|
||||
@@ -306,6 +286,30 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
_ = log1.pop("train_samples_per_second", None)
|
||||
self.assertEqual(log, log1)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
args = TrainingArguments(".")
|
||||
self.n_epochs = args.num_train_epochs
|
||||
self.batch_size = args.train_batch_size
|
||||
trainer = get_regression_trainer(learning_rate=0.1)
|
||||
trainer.train()
|
||||
self.default_trained_model = (trainer.model.a, trainer.model.b)
|
||||
|
||||
trainer = get_regression_trainer(learning_rate=0.1, seed=314)
|
||||
trainer.train()
|
||||
self.alternate_trained_model = (trainer.model.a, trainer.model.b)
|
||||
|
||||
def check_trained_model(self, model, alternate_seed=False):
|
||||
# Checks a training seeded with learning_rate = 0.1
|
||||
(a, b) = self.alternate_trained_model if alternate_seed else self.default_trained_model
|
||||
self.assertTrue(torch.allclose(model.a, a))
|
||||
self.assertTrue(torch.allclose(model.b, b))
|
||||
|
||||
def test_trainer_works_with_dict(self):
|
||||
# Edge case because Apex with mode O2 will change our models to return dicts. This test checks it doesn't break
|
||||
# anything.
|
||||
@@ -607,6 +611,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
|
||||
# won't be the same since the training dataloader is shuffled).
|
||||
return
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
||||
trainer.train()
|
||||
|
||||
Reference in New Issue
Block a user