[DeepSpeed] improve checkpoint loading code plus tests (#10760)

* deepspeed checkpoint loading code plus tests

* style

* style
This commit is contained in:
Stas Bekman
2021-03-17 10:22:58 -07:00
committed by GitHub
parent 01c7fb04be
commit cd8c93f701
4 changed files with 169 additions and 34 deletions

View File

@@ -24,6 +24,7 @@ import numpy as np
from transformers import AutoTokenizer, IntervalStrategy, PretrainedConfig, TrainingArguments, is_torch_available
from transformers.file_utils import WEIGHTS_NAME
from transformers.testing_utils import (
TestCasePlus,
get_tests_dir,
require_datasets,
require_optuna,
@@ -235,28 +236,7 @@ if is_torch_available():
)
@require_torch
@require_sentencepiece
@require_tokenizers
class TrainerIntegrationTest(unittest.TestCase):
def setUp(self):
args = TrainingArguments(".")
self.n_epochs = args.num_train_epochs
self.batch_size = args.train_batch_size
trainer = get_regression_trainer(learning_rate=0.1)
trainer.train()
self.default_trained_model = (trainer.model.a, trainer.model.b)
trainer = get_regression_trainer(learning_rate=0.1, seed=314)
trainer.train()
self.alternate_trained_model = (trainer.model.a, trainer.model.b)
def check_trained_model(self, model, alternate_seed=False):
# Checks a training seeded with learning_rate = 0.1
(a, b) = self.alternate_trained_model if alternate_seed else self.default_trained_model
self.assertTrue(torch.allclose(model.a, a))
self.assertTrue(torch.allclose(model.b, b))
class TrainerIntegrationCommon:
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True):
file_list = [WEIGHTS_NAME, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
if is_pretrained:
@@ -306,6 +286,30 @@ class TrainerIntegrationTest(unittest.TestCase):
_ = log1.pop("train_samples_per_second", None)
self.assertEqual(log, log1)
@require_torch
@require_sentencepiece
@require_tokenizers
class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
def setUp(self):
super().setUp()
args = TrainingArguments(".")
self.n_epochs = args.num_train_epochs
self.batch_size = args.train_batch_size
trainer = get_regression_trainer(learning_rate=0.1)
trainer.train()
self.default_trained_model = (trainer.model.a, trainer.model.b)
trainer = get_regression_trainer(learning_rate=0.1, seed=314)
trainer.train()
self.alternate_trained_model = (trainer.model.a, trainer.model.b)
def check_trained_model(self, model, alternate_seed=False):
# Checks a training seeded with learning_rate = 0.1
(a, b) = self.alternate_trained_model if alternate_seed else self.default_trained_model
self.assertTrue(torch.allclose(model.a, a))
self.assertTrue(torch.allclose(model.b, b))
def test_trainer_works_with_dict(self):
# Edge case because Apex with mode O2 will change our models to return dicts. This test checks it doesn't break
# anything.
@@ -607,6 +611,7 @@ class TrainerIntegrationTest(unittest.TestCase):
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
# won't be the same since the training dataloader is shuffled).
return
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
trainer.train()