deepspeed resume from ckpt fixes and adding support for deepspeed optimizer and HF scheduler (#25863)

* Add support for deepspeed optimizer and HF scheduler

* fix bug

* fix the import

* fix issue with deepspeed scheduler saving for hf optim + hf scheduler scenario

* fix loading of hf scheduler when loading deepspeed checkpoint

* fix import of `DeepSpeedSchedulerWrapper`

* add tests

* add the comment and skip the failing tests

* address comment
This commit is contained in:
Sourab Mangrulkar
2023-09-05 22:31:20 +05:30
committed by GitHub
parent 1110b565d6
commit 6bc517ccd4
3 changed files with 55 additions and 10 deletions

View File

@@ -26,6 +26,8 @@ from ..utils import is_accelerate_available, is_torch_available, logging
if is_torch_available():
import torch
from ..optimization import get_scheduler
logger = logging.get_logger(__name__)
@@ -274,7 +276,7 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
# 1. DS scheduler + DS optimizer: Yes
# 2. HF scheduler + HF optimizer: Mostly*
# 3. DS scheduler + HF optimizer: Mostly*
# 4. HF scheduler + DS optimizer: No
# 4. HF scheduler + DS optimizer: Yes
#
# Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB)
@@ -304,11 +306,18 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
lr_scheduler = DummyScheduler(optimizer)
else:
if isinstance(optimizer, DummyOptim):
raise ValueError(
"Found `optimizer` configured in the DeepSpeed config, but no `scheduler`. "
"Please configure a scheduler in the DeepSpeed config."
)
lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
def _lr_scheduler_callable(optimizer):
return get_scheduler(
trainer.args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=trainer.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable)
else:
lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
return optimizer, lr_scheduler

View File

@@ -60,7 +60,7 @@ from .data.data_collator import DataCollator, DataCollatorWithPadding, default_d
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .dependency_versions_check import dep_version_check
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
@@ -212,6 +212,9 @@ if is_accelerate_available():
save_fsdp_optimizer,
)
if is_deepspeed_available():
from accelerate.utils import DeepSpeedSchedulerWrapper
if TYPE_CHECKING:
import optuna
@@ -2362,7 +2365,14 @@ class Trainer:
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
# Save SCHEDULER & SCALER
if self.args.should_save and not self.is_deepspeed_enabled and not is_torch_tpu_available():
is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
self.lr_scheduler, DeepSpeedSchedulerWrapper
)
if (
self.args.should_save
and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
and not is_torch_tpu_available()
):
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
@@ -2428,6 +2438,10 @@ class Trainer:
if self.is_deepspeed_enabled:
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
reissue_pt_warnings(caught_warnings)
return
checkpoint_file_exists = (

View File

@@ -136,6 +136,14 @@ ZERO3 = "zero3"
FP16 = "fp16"
BF16 = "bf16"
HF_OPTIM = "hf_optim"
HF_SCHEDULER = "hf_scheduler"
DS_OPTIM = "ds_optim"
DS_SCHEDULER = "ds_scheduler"
optims = [HF_OPTIM, DS_OPTIM]
schedulers = [HF_SCHEDULER, DS_SCHEDULER]
stages = [ZERO2, ZERO3]
if is_torch_bf16_gpu_available():
dtypes = [FP16, BF16]
@@ -153,6 +161,8 @@ def parameterized_custom_name_func(func, param_num, param):
# Cartesian-product of zero stages with models to test
params = list(itertools.product(stages, dtypes))
params_with_optims_and_schedulers = list(itertools.product(stages, dtypes, optims, schedulers))
@require_deepspeed
@require_torch_gpu
@@ -640,10 +650,16 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
"Can't find a valid checkpoint at" in str(context.exception), f"got exception: {context.exception}"
)
@parameterized.expand(params, name_func=parameterized_custom_name_func)
def test_can_resume_training_normal(self, stage, dtype):
@parameterized.expand(params_with_optims_and_schedulers, name_func=parameterized_custom_name_func)
def test_can_resume_training_normal(self, stage, dtype, optim, scheduler):
# adapted from TrainerIntegrationTest.test_can_resume_training
# test normal resume for each stage separately, error-handling is tested in a different test
# ToDo: Currently, hf_optim + hf_scheduler resumes with the correct states and
# also has same losses for few steps but then slowly diverges. Need to figure it out.
if optim == HF_OPTIM and scheduler == HF_SCHEDULER:
return
output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
ds_config_dict = self.get_config_dict(stage)
if dtype == FP16:
@@ -652,6 +668,12 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
if stage == ZERO3:
ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True
if optim == HF_OPTIM:
del ds_config_dict["optimizer"]
if scheduler == HF_SCHEDULER:
del ds_config_dict["scheduler"]
kwargs = {
"output_dir": output_dir,
"train_len": 128,