[deepspeed] fix --load_best_model_at_end (#14652)
* [deepspeed] fix load_best_model_at_end * try with pull_request_target * revert: try with pull_request_target * style * add test * cleanup
This commit is contained in:
@@ -357,6 +357,18 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
|
||||
return optimizer, lr_scheduler
|
||||
|
||||
|
||||
def deepspeed_reinit(trainer):
|
||||
"""
|
||||
this is a temp hack based on: https://github.com/microsoft/DeepSpeed/issues/1394#issuecomment-937405374 until
|
||||
Deepspeed fixes a bug where it can't resume from a checkpoint after it did some stepping
|
||||
https://github.com/microsoft/DeepSpeed/issues/1612
|
||||
"""
|
||||
import deepspeed
|
||||
|
||||
deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**trainer.deepspeed_initialize_kwargs)
|
||||
return deepspeed_engine, optimizer, lr_scheduler
|
||||
|
||||
|
||||
def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
|
||||
"""
|
||||
Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args.
|
||||
@@ -398,12 +410,12 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
|
||||
model_parameters = None
|
||||
else:
|
||||
optimizer, lr_scheduler = deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps)
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
|
||||
|
||||
# keep for quick debug:
|
||||
# from pprint import pprint; pprint(config)
|
||||
|
||||
model, optimizer, _, lr_scheduler = deepspeed.initialize(
|
||||
kwargs = dict(
|
||||
model=model,
|
||||
model_parameters=model_parameters,
|
||||
config_params=config,
|
||||
@@ -411,6 +423,11 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
|
||||
deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
|
||||
|
||||
# stash kwargs to enabled a later deepspeed_reinit
|
||||
trainer.deepspeed_initialize_kwargs = kwargs
|
||||
|
||||
if resume_from_checkpoint is not None:
|
||||
|
||||
# it's possible that the user is trying to resume from model_path, which doesn't necessarily
|
||||
@@ -424,7 +441,7 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
|
||||
if len(deepspeed_checkpoint_dirs) > 0:
|
||||
logger.info(f"Attempting to resume from {resume_from_checkpoint}")
|
||||
# this magically updates self.optimizer and self.lr_scheduler
|
||||
load_path, _ = model.load_checkpoint(
|
||||
load_path, _ = deepspeed_engine.load_checkpoint(
|
||||
resume_from_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
|
||||
)
|
||||
if load_path is None:
|
||||
@@ -432,4 +449,4 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
|
||||
else:
|
||||
logger.info(f"{resume_from_checkpoint} doesn't have deepspeed checkpoints, doing nothing")
|
||||
|
||||
return model, optimizer, lr_scheduler
|
||||
return deepspeed_engine, optimizer, lr_scheduler
|
||||
|
||||
@@ -59,7 +59,7 @@ from . import __version__
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
||||
from .debug_utils import DebugOption, DebugUnderflowOverflow
|
||||
from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
|
||||
from .deepspeed import deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enabled
|
||||
from .dependency_versions_check import dep_version_check
|
||||
from .file_utils import (
|
||||
CONFIG_NAME,
|
||||
@@ -1434,6 +1434,18 @@ class Trainer:
|
||||
|
||||
best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
|
||||
if os.path.exists(best_model_path):
|
||||
if self.deepspeed:
|
||||
# temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
|
||||
deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(self)
|
||||
self.model = deepspeed_engine.module
|
||||
self.model_wrapped = deepspeed_engine
|
||||
self.deepspeed = deepspeed_engine
|
||||
self.optimizer = optimizer
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self.deepspeed.load_checkpoint(
|
||||
self.state.best_model_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
|
||||
)
|
||||
else:
|
||||
# We load the model state dict on the CPU to avoid an OOM error.
|
||||
state_dict = torch.load(best_model_path, map_location="cpu")
|
||||
# If the model is on the GPU, it still works!
|
||||
@@ -1444,11 +1456,6 @@ class Trainer:
|
||||
"on multiple nodes, you should activate `--save_on_each_node`."
|
||||
)
|
||||
|
||||
if self.deepspeed:
|
||||
self.deepspeed.load_checkpoint(
|
||||
self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False
|
||||
)
|
||||
|
||||
# add remaining tr_loss
|
||||
self._total_loss_scalar += tr_loss.item()
|
||||
train_loss = self._total_loss_scalar / self.state.global_step
|
||||
@@ -1975,6 +1982,9 @@ class Trainer:
|
||||
# This must be called on all ranks
|
||||
self.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME)
|
||||
|
||||
# save a deepspeed checkpoint as well (this is very fast)
|
||||
self.deepspeed.save_checkpoint(output_dir)
|
||||
|
||||
elif self.args.should_save:
|
||||
self._save(output_dir)
|
||||
|
||||
|
||||
@@ -602,6 +602,11 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(b, b1)
|
||||
self.check_trainer_state_are_the_same(state, state1)
|
||||
|
||||
# Finally, should be able to resume with the same trainer/same deepspeed engine instance
|
||||
# XXX: but currently this not possible due DS bug: https://github.com/microsoft/DeepSpeed/issues/1612
|
||||
# trainer.train(resume_from_checkpoint=checkpoint)
|
||||
# a workaround needs to be used that re-creates the deepspeed engine
|
||||
|
||||
@parameterized.expand(stages)
|
||||
def test_load_state_dict_from_zero_checkpoint(self, stage):
|
||||
# test that we can load fp32 weights directly from the zero checkpoint into the current model
|
||||
@@ -968,3 +973,49 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
|
||||
with CaptureStderr() as cs:
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
assert "Detected DeepSpeed ZeRO-3" in cs.err
|
||||
|
||||
@parameterized.expand(stages)
|
||||
def test_load_best_model(self, stage):
|
||||
# this test exercises --load_best_model_at_end - the key is being able to resume after some training
|
||||
|
||||
data_dir = self.tests_dir / "fixtures/tests_samples/wmt_en_ro"
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args = f"""
|
||||
--model_name_or_path {T5_TINY}
|
||||
--tokenizer_name {T5_TINY}
|
||||
--train_file {data_dir}/train.json
|
||||
--validation_file {data_dir}/val.json
|
||||
--output_dir {output_dir}
|
||||
--overwrite_output_dir
|
||||
--source_lang en
|
||||
--target_lang ro
|
||||
--do_train
|
||||
--max_train_samples 3
|
||||
--do_eval
|
||||
--max_eval_samples 1
|
||||
--logging_strategy steps
|
||||
--logging_steps 1
|
||||
--evaluation_strategy steps
|
||||
--eval_steps 1
|
||||
--save_strategy steps
|
||||
--save_steps 1
|
||||
--load_best_model_at_end
|
||||
--per_device_train_batch_size 1
|
||||
--per_device_eval_batch_size 1
|
||||
--num_train_epochs 1
|
||||
--fp16
|
||||
--report_to none
|
||||
""".split()
|
||||
args.extend(["--source_prefix", "translate English to Romanian: "])
|
||||
|
||||
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_zero3.json".split()
|
||||
script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"]
|
||||
launcher = get_launcher(distributed=False)
|
||||
|
||||
cmd = launcher + script + args + ds_args
|
||||
# keep for quick debug
|
||||
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
|
||||
with CaptureStderr() as cs:
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
# enough to test it didn't fail
|
||||
assert "Detected DeepSpeed ZeRO-3" in cs.err
|
||||
|
||||
Reference in New Issue
Block a user