[deepspeed] fix --load_best_model_at_end (#14652)

* [deepspeed] fix load_best_model_at_end

* try with pull_request_target

* revert: try with pull_request_target

* style

* add test

* cleanup
This commit is contained in:
Stas Bekman
2021-12-06 21:57:47 -08:00
committed by GitHub
parent 30646a0a3c
commit b66c5ab20c
3 changed files with 92 additions and 14 deletions

View File

@@ -602,6 +602,11 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1)
# Finally, should be able to resume with the same trainer/same deepspeed engine instance
# XXX: but currently this not possible due DS bug: https://github.com/microsoft/DeepSpeed/issues/1612
# trainer.train(resume_from_checkpoint=checkpoint)
# a workaround needs to be used that re-creates the deepspeed engine
@parameterized.expand(stages)
def test_load_state_dict_from_zero_checkpoint(self, stage):
# test that we can load fp32 weights directly from the zero checkpoint into the current model
@@ -968,3 +973,49 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
with CaptureStderr() as cs:
execute_subprocess_async(cmd, env=self.get_env())
assert "Detected DeepSpeed ZeRO-3" in cs.err
@parameterized.expand(stages)
def test_load_best_model(self, stage):
# this test exercises --load_best_model_at_end - the key is being able to resume after some training
data_dir = self.tests_dir / "fixtures/tests_samples/wmt_en_ro"
output_dir = self.get_auto_remove_tmp_dir()
args = f"""
--model_name_or_path {T5_TINY}
--tokenizer_name {T5_TINY}
--train_file {data_dir}/train.json
--validation_file {data_dir}/val.json
--output_dir {output_dir}
--overwrite_output_dir
--source_lang en
--target_lang ro
--do_train
--max_train_samples 3
--do_eval
--max_eval_samples 1
--logging_strategy steps
--logging_steps 1
--evaluation_strategy steps
--eval_steps 1
--save_strategy steps
--save_steps 1
--load_best_model_at_end
--per_device_train_batch_size 1
--per_device_eval_batch_size 1
--num_train_epochs 1
--fp16
--report_to none
""".split()
args.extend(["--source_prefix", "translate English to Romanian: "])
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_zero3.json".split()
script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"]
launcher = get_launcher(distributed=False)
cmd = launcher + script + args + ds_args
# keep for quick debug
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
with CaptureStderr() as cs:
execute_subprocess_async(cmd, env=self.get_env())
# enough to test it didn't fail
assert "Detected DeepSpeed ZeRO-3" in cs.err