From b66c5ab20c8bb08d52cb840382498f936ea8da03 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 6 Dec 2021 21:57:47 -0800 Subject: [PATCH] [deepspeed] fix --load_best_model_at_end (#14652) * [deepspeed] fix load_best_model_at_end * try with pull_request_target * revert: try with pull_request_target * style * add test * cleanup --- src/transformers/deepspeed.py | 25 ++++++++++++--- src/transformers/trainer.py | 30 ++++++++++++------ tests/deepspeed/test_deepspeed.py | 51 +++++++++++++++++++++++++++++++ 3 files changed, 92 insertions(+), 14 deletions(-) diff --git a/src/transformers/deepspeed.py b/src/transformers/deepspeed.py index edbcbd50cc..d1d2114c89 100644 --- a/src/transformers/deepspeed.py +++ b/src/transformers/deepspeed.py @@ -357,6 +357,18 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps return optimizer, lr_scheduler +def deepspeed_reinit(trainer): + """ + this is a temp hack based on: https://github.com/microsoft/DeepSpeed/issues/1394#issuecomment-937405374 until + Deepspeed fixes a bug where it can't resume from a checkpoint after it did some stepping + https://github.com/microsoft/DeepSpeed/issues/1612 + """ + import deepspeed + + deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**trainer.deepspeed_initialize_kwargs) + return deepspeed_engine, optimizer, lr_scheduler + + def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False): """ Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args. @@ -398,12 +410,12 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf model_parameters = None else: optimizer, lr_scheduler = deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps) - model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + model_parameters = list(filter(lambda p: p.requires_grad, model.parameters())) # keep for quick debug: # from pprint import pprint; pprint(config) - model, optimizer, _, lr_scheduler = deepspeed.initialize( + kwargs = dict( model=model, model_parameters=model_parameters, config_params=config, @@ -411,6 +423,11 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf lr_scheduler=lr_scheduler, ) + deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs) + + # stash kwargs to enabled a later deepspeed_reinit + trainer.deepspeed_initialize_kwargs = kwargs + if resume_from_checkpoint is not None: # it's possible that the user is trying to resume from model_path, which doesn't necessarily @@ -424,7 +441,7 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf if len(deepspeed_checkpoint_dirs) > 0: logger.info(f"Attempting to resume from {resume_from_checkpoint}") # this magically updates self.optimizer and self.lr_scheduler - load_path, _ = model.load_checkpoint( + load_path, _ = deepspeed_engine.load_checkpoint( resume_from_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True ) if load_path is None: @@ -432,4 +449,4 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf else: logger.info(f"{resume_from_checkpoint} doesn't have deepspeed checkpoints, doing nothing") - return model, optimizer, lr_scheduler + return deepspeed_engine, optimizer, lr_scheduler diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 381bb0244a..107e06839f 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -59,7 +59,7 @@ from . import __version__ from .configuration_utils import PretrainedConfig from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator from .debug_utils import DebugOption, DebugUnderflowOverflow -from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled +from .deepspeed import deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enabled from .dependency_versions_check import dep_version_check from .file_utils import ( CONFIG_NAME, @@ -1434,21 +1434,28 @@ class Trainer: best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME) if os.path.exists(best_model_path): - # We load the model state dict on the CPU to avoid an OOM error. - state_dict = torch.load(best_model_path, map_location="cpu") - # If the model is on the GPU, it still works! - self._load_state_dict_in_model(state_dict) + if self.deepspeed: + # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping + deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(self) + self.model = deepspeed_engine.module + self.model_wrapped = deepspeed_engine + self.deepspeed = deepspeed_engine + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.deepspeed.load_checkpoint( + self.state.best_model_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True + ) + else: + # We load the model state dict on the CPU to avoid an OOM error. + state_dict = torch.load(best_model_path, map_location="cpu") + # If the model is on the GPU, it still works! + self._load_state_dict_in_model(state_dict) else: logger.warn( f"Could not locate the best model at {best_model_path}, if you are running a distributed training " "on multiple nodes, you should activate `--save_on_each_node`." ) - if self.deepspeed: - self.deepspeed.load_checkpoint( - self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False - ) - # add remaining tr_loss self._total_loss_scalar += tr_loss.item() train_loss = self._total_loss_scalar / self.state.global_step @@ -1975,6 +1982,9 @@ class Trainer: # This must be called on all ranks self.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME) + # save a deepspeed checkpoint as well (this is very fast) + self.deepspeed.save_checkpoint(output_dir) + elif self.args.should_save: self._save(output_dir) diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index 8e7587235d..8aaf789b97 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -602,6 +602,11 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon): self.assertEqual(b, b1) self.check_trainer_state_are_the_same(state, state1) + # Finally, should be able to resume with the same trainer/same deepspeed engine instance + # XXX: but currently this not possible due DS bug: https://github.com/microsoft/DeepSpeed/issues/1612 + # trainer.train(resume_from_checkpoint=checkpoint) + # a workaround needs to be used that re-creates the deepspeed engine + @parameterized.expand(stages) def test_load_state_dict_from_zero_checkpoint(self, stage): # test that we can load fp32 weights directly from the zero checkpoint into the current model @@ -968,3 +973,49 @@ class TestDeepSpeedWithLauncher(TestCasePlus): with CaptureStderr() as cs: execute_subprocess_async(cmd, env=self.get_env()) assert "Detected DeepSpeed ZeRO-3" in cs.err + + @parameterized.expand(stages) + def test_load_best_model(self, stage): + # this test exercises --load_best_model_at_end - the key is being able to resume after some training + + data_dir = self.tests_dir / "fixtures/tests_samples/wmt_en_ro" + output_dir = self.get_auto_remove_tmp_dir() + args = f""" + --model_name_or_path {T5_TINY} + --tokenizer_name {T5_TINY} + --train_file {data_dir}/train.json + --validation_file {data_dir}/val.json + --output_dir {output_dir} + --overwrite_output_dir + --source_lang en + --target_lang ro + --do_train + --max_train_samples 3 + --do_eval + --max_eval_samples 1 + --logging_strategy steps + --logging_steps 1 + --evaluation_strategy steps + --eval_steps 1 + --save_strategy steps + --save_steps 1 + --load_best_model_at_end + --per_device_train_batch_size 1 + --per_device_eval_batch_size 1 + --num_train_epochs 1 + --fp16 + --report_to none + """.split() + args.extend(["--source_prefix", "translate English to Romanian: "]) + + ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_zero3.json".split() + script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"] + launcher = get_launcher(distributed=False) + + cmd = launcher + script + args + ds_args + # keep for quick debug + # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die + with CaptureStderr() as cs: + execute_subprocess_async(cmd, env=self.get_env()) + # enough to test it didn't fail + assert "Detected DeepSpeed ZeRO-3" in cs.err