[Deepspeed] add many more models to the model zoo test (#12695)
* model zoo take 2 * add deberta * new param for zero2 * doc update * doc update * add layoutlm * bump deepspeed * add deberta-v2, funnel, longformer * new models * style * add t5_v1 * update TAPAS status * reorg problematic models * move doc to another PR * style * fix checkpoint check test * making progress on more models running * cleanup * new version * cleanup
This commit is contained in:
@@ -522,7 +522,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
# see the note above how to get identical loss on a small bs
|
||||
self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=2)
|
||||
|
||||
def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage):
|
||||
def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype):
|
||||
# adapted from TrainerIntegrationCommon.check_saved_checkpoints
|
||||
|
||||
file_list = [WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
|
||||
@@ -534,7 +534,8 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
else:
|
||||
raise ValueError(f"unknown stage {stage}")
|
||||
|
||||
ds_file_list.append("zero_pp_rank_0_mp_rank_00_optim_states.pt")
|
||||
if dtype == "bf16":
|
||||
ds_file_list.append("bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt")
|
||||
|
||||
for step in range(freq, total, freq):
|
||||
checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
|
||||
@@ -578,7 +579,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer.train()
|
||||
|
||||
total = int(self.n_epochs * 64 / self.batch_size)
|
||||
self.check_saved_checkpoints_deepspeed(output_dir, freq, total, stage)
|
||||
self.check_saved_checkpoints_deepspeed(output_dir, freq, total, stage, dtype)
|
||||
|
||||
@parameterized.expand(params, name_func=parameterized_custom_name_func)
|
||||
def test_can_resume_training_errors(self, stage, dtype):
|
||||
|
||||
Reference in New Issue
Block a user