[deepspeed] fix --load_best_model_at_end (#14652)
* [deepspeed] fix load_best_model_at_end
* try with pull_request_target
* revert: try with pull_request_target
* style
* add test
* cleanup
This commit is contained in:
@@ -602,6 +602,11 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(b, b1)
|
||||
self.check_trainer_state_are_the_same(state, state1)
|
||||
|
||||
# Finally, should be able to resume with the same trainer/same deepspeed engine instance
|
||||
# XXX: but currently this is not possible due to a DS bug: https://github.com/microsoft/DeepSpeed/issues/1612
|
||||
# trainer.train(resume_from_checkpoint=checkpoint)
|
||||
# a workaround needs to be used that re-creates the deepspeed engine
|
||||
|
||||
@parameterized.expand(stages)
|
||||
def test_load_state_dict_from_zero_checkpoint(self, stage):
|
||||
# test that we can load fp32 weights directly from the zero checkpoint into the current model
|
||||
@@ -968,3 +973,49 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
|
||||
with CaptureStderr() as cs:
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
assert "Detected DeepSpeed ZeRO-3" in cs.err
|
||||
|
||||
@parameterized.expand(stages)
def test_load_best_model(self, stage):
    """Exercise ``--load_best_model_at_end`` through the translation example script.

    The run logs, evaluates and saves on every step so that a checkpoint
    exists to be reloaded at the end of training - the key is being able to
    resume after some training. The test passes when the subprocess finishes
    and its stderr contains the ZeRO-3 detection message.

    NOTE(review): ``stage`` is not referenced in the body; the DeepSpeed
    config is always ``ds_config_zero3.json``. Confirm whether running the
    same ZeRO-3 scenario once per parameterized stage is intended.
    """
    samples_dir = self.tests_dir / "fixtures/tests_samples/wmt_en_ro"
    tmp_output_dir = self.get_auto_remove_tmp_dir()

    # whitespace-separated CLI tokens for run_translation.py
    script_args = f"""
        --model_name_or_path {T5_TINY}
        --tokenizer_name {T5_TINY}
        --train_file {samples_dir}/train.json
        --validation_file {samples_dir}/val.json
        --output_dir {tmp_output_dir}
        --overwrite_output_dir
        --source_lang en
        --target_lang ro
        --do_train
        --max_train_samples 3
        --do_eval
        --max_eval_samples 1
        --logging_strategy steps
        --logging_steps 1
        --evaluation_strategy steps
        --eval_steps 1
        --save_strategy steps
        --save_steps 1
        --load_best_model_at_end
        --per_device_train_batch_size 1
        --per_device_eval_batch_size 1
        --num_train_epochs 1
        --fp16
        --report_to none
    """.split()
    # the prefix contains spaces, so it is appended as a single pre-split token
    script_args += ["--source_prefix", "translate English to Romanian: "]

    deepspeed_args = ["--deepspeed", f"{self.test_file_dir_str}/ds_config_zero3.json"]
    script_path = f"{self.examples_dir_str}/pytorch/translation/run_translation.py"
    launcher_cmd = get_launcher(distributed=False)

    cmd = [*launcher_cmd, script_path, *script_args, *deepspeed_args]
    # keep for quick debug
    # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
    with CaptureStderr() as captured:
        execute_subprocess_async(cmd, env=self.get_env())
    # enough to test it didn't fail
    assert "Detected DeepSpeed ZeRO-3" in captured.err
|
||||
|
||||
Reference in New Issue
Block a user