[DeepSpeed] improve checkpoint loading code plus tests (#10760)
* deepspeed checkpoint loading code plus tests * style * style
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import dataclasses
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
@@ -19,6 +20,8 @@ import sys
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
|
||||
from transformers import TrainingArguments
|
||||
from transformers.file_utils import WEIGHTS_NAME
|
||||
from transformers.integrations import is_deepspeed_available
|
||||
from transformers.testing_utils import (
|
||||
CaptureStd,
|
||||
@@ -35,7 +38,7 @@ from transformers.trainer_utils import set_seed
|
||||
|
||||
bindir = os.path.abspath(os.path.dirname(__file__))
|
||||
sys.path.append(f"{bindir}/../../../tests")
|
||||
from test_trainer import get_regression_trainer # noqa
|
||||
from test_trainer import TrainerIntegrationCommon, get_regression_trainer # noqa
|
||||
|
||||
|
||||
set_seed(42)
|
||||
@@ -60,11 +63,21 @@ def require_deepspeed(test_case):
|
||||
|
||||
@require_deepspeed
|
||||
@require_torch_gpu
|
||||
class TrainerIntegrationDeepSpeed(TestCasePlus):
|
||||
""" This class is for testing directly via get_regression_trainer """
|
||||
class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
"""
|
||||
|
||||
This class is for testing directly via get_regression_trainer
|
||||
|
||||
It mixes in `TrainerIntegrationCommon` which already has a lot of helper validation methods which we can re-use here.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
args = TrainingArguments(".")
|
||||
self.n_epochs = args.num_train_epochs
|
||||
self.batch_size = args.train_batch_size
|
||||
|
||||
self.dist_env_1_gpu = dict(
|
||||
MASTER_ADDR="localhost", MASTER_PORT="10999", RANK="0", LOCAL_RANK="0", WORLD_SIZE="1"
|
||||
)
|
||||
@@ -222,6 +235,101 @@ class TrainerIntegrationDeepSpeed(TestCasePlus):
|
||||
# see the note above how to get identical loss on a small bs
|
||||
self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=5)
|
||||
|
||||
def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, is_pretrained=True):
|
||||
# adapted from TrainerIntegrationCommon.check_saved_checkpoints
|
||||
|
||||
file_list = [WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
|
||||
ds_file_list = ["mp_rank_00_model_states.pt", "zero_pp_rank_0_mp_rank_00optim_states.pt"]
|
||||
|
||||
for step in range(freq, total, freq):
|
||||
checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
|
||||
self.assertTrue(os.path.isdir(checkpoint))
|
||||
|
||||
# common files
|
||||
for filename in file_list:
|
||||
self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename)))
|
||||
|
||||
# ds files
|
||||
ds_path = os.path.join(checkpoint, f"global_step{step}")
|
||||
for filename in ds_file_list:
|
||||
# filename = os.path.join(path, filename)
|
||||
# print(filename)
|
||||
self.assertTrue(os.path.isfile(os.path.join(ds_path, filename)))
|
||||
|
||||
def test_save_checkpoints(self):
|
||||
# adapted from TrainerIntegrationTest.test_save_checkpoints
|
||||
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
ds_config_dict = deepcopy(self.ds_config_dict)
|
||||
ds_config_dict["fp16"]["initial_scale_power"] = 1 # force optimizer on the first step
|
||||
freq = 5
|
||||
|
||||
# save checkpoints
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=output_dir,
|
||||
save_steps=freq,
|
||||
deepspeed=ds_config_dict,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
total = int(self.n_epochs * 64 / self.batch_size)
|
||||
self.check_saved_checkpoints_deepspeed(output_dir, freq, total)
|
||||
|
||||
def test_can_resume_training(self):
|
||||
# adapted from TrainerIntegrationTest.test_can_resume_training
|
||||
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
ds_config_dict = deepcopy(self.ds_config_dict)
|
||||
ds_config_dict["fp16"]["initial_scale_power"] = 1 # force optimizer on the first step
|
||||
kwargs = dict(output_dir=output_dir, train_len=128, save_steps=5, learning_rate=0.1, deepspeed=ds_config_dict)
|
||||
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
trainer = get_regression_trainer(**kwargs)
|
||||
trainer.train()
|
||||
(a, b) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state = dataclasses.asdict(trainer.state)
|
||||
|
||||
checkpoint = os.path.join(output_dir, "checkpoint-5")
|
||||
|
||||
# Reinitialize trainer
|
||||
trainer = get_regression_trainer(**kwargs)
|
||||
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
self.assertEqual(a, a1)
|
||||
self.assertEqual(b, b1)
|
||||
self.check_trainer_state_are_the_same(state, state1)
|
||||
|
||||
# Now check with a later checkpoint that it also works when we span over one epoch
|
||||
checkpoint = os.path.join(output_dir, "checkpoint-15")
|
||||
|
||||
# Reinitialize trainer and load model
|
||||
trainer = get_regression_trainer(**kwargs)
|
||||
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
self.assertEqual(a, a1)
|
||||
self.assertEqual(b, b1)
|
||||
self.check_trainer_state_are_the_same(state, state1)
|
||||
|
||||
# Now check failures
|
||||
|
||||
# 1. fail to find a bogus checkpoint
|
||||
trainer = get_regression_trainer(**kwargs)
|
||||
with self.assertRaises(Exception) as context:
|
||||
trainer.train(resume_from_checkpoint=f"{checkpoint}-bogus")
|
||||
self.assertTrue("failed to resume from checkpoint" in str(context.exception))
|
||||
|
||||
# 2. fail to find any checkpoint - due a fresh output_dir
|
||||
output_dir2 = self.get_auto_remove_tmp_dir()
|
||||
trainer = get_regression_trainer(output_dir=output_dir2, deepspeed=ds_config_dict)
|
||||
with self.assertRaises(Exception) as context:
|
||||
trainer.train(resume_from_checkpoint=True)
|
||||
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
|
||||
|
||||
|
||||
@slow
|
||||
@require_deepspeed
|
||||
|
||||
Reference in New Issue
Block a user