deepspeed resume from ckpt fixes and adding support for deepspeed optimizer and HF scheduler (#25863)
* Add support for deepspeed optimizer and HF scheduler * fix bug * fix the import * fix issue with deepspeed scheduler saving for hf optim + hf scheduler scenario * fix loading of hf scheduler when loading deepspeed checkpoint * fix import of `DeepSpeedSchedulerWrapper` * add tests * add the comment and skip the failing tests * address comment
This commit is contained in:
committed by
GitHub
parent
1110b565d6
commit
6bc517ccd4
@@ -26,6 +26,8 @@ from ..utils import is_accelerate_available, is_torch_available, logging
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from ..optimization import get_scheduler
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -274,7 +276,7 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
|
||||
# 1. DS scheduler + DS optimizer: Yes
|
||||
# 2. HF scheduler + HF optimizer: Mostly*
|
||||
# 3. DS scheduler + HF optimizer: Mostly*
|
||||
# 4. HF scheduler + DS optimizer: No
|
||||
# 4. HF scheduler + DS optimizer: Yes
|
||||
#
|
||||
# Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB)
|
||||
|
||||
@@ -304,11 +306,18 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
|
||||
lr_scheduler = DummyScheduler(optimizer)
|
||||
else:
|
||||
if isinstance(optimizer, DummyOptim):
|
||||
raise ValueError(
|
||||
"Found `optimizer` configured in the DeepSpeed config, but no `scheduler`. "
|
||||
"Please configure a scheduler in the DeepSpeed config."
|
||||
)
|
||||
lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
|
||||
|
||||
def _lr_scheduler_callable(optimizer):
|
||||
return get_scheduler(
|
||||
trainer.args.lr_scheduler_type,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=trainer.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
|
||||
lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable)
|
||||
else:
|
||||
lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
|
||||
|
||||
return optimizer, lr_scheduler
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ from .data.data_collator import DataCollator, DataCollatorWithPadding, default_d
|
||||
from .debug_utils import DebugOption, DebugUnderflowOverflow
|
||||
from .dependency_versions_check import dep_version_check
|
||||
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
|
||||
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint
|
||||
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
|
||||
from .modelcard import TrainingSummary
|
||||
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
|
||||
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
|
||||
@@ -212,6 +212,9 @@ if is_accelerate_available():
|
||||
save_fsdp_optimizer,
|
||||
)
|
||||
|
||||
if is_deepspeed_available():
|
||||
from accelerate.utils import DeepSpeedSchedulerWrapper
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import optuna
|
||||
@@ -2362,7 +2365,14 @@ class Trainer:
|
||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||
|
||||
# Save SCHEDULER & SCALER
|
||||
if self.args.should_save and not self.is_deepspeed_enabled and not is_torch_tpu_available():
|
||||
is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
|
||||
self.lr_scheduler, DeepSpeedSchedulerWrapper
|
||||
)
|
||||
if (
|
||||
self.args.should_save
|
||||
and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
|
||||
and not is_torch_tpu_available()
|
||||
):
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
@@ -2428,6 +2438,10 @@ class Trainer:
|
||||
|
||||
if self.is_deepspeed_enabled:
|
||||
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
|
||||
if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
return
|
||||
|
||||
checkpoint_file_exists = (
|
||||
|
||||
@@ -136,6 +136,14 @@ ZERO3 = "zero3"
|
||||
FP16 = "fp16"
|
||||
BF16 = "bf16"
|
||||
|
||||
HF_OPTIM = "hf_optim"
|
||||
HF_SCHEDULER = "hf_scheduler"
|
||||
DS_OPTIM = "ds_optim"
|
||||
DS_SCHEDULER = "ds_scheduler"
|
||||
|
||||
optims = [HF_OPTIM, DS_OPTIM]
|
||||
schedulers = [HF_SCHEDULER, DS_SCHEDULER]
|
||||
|
||||
stages = [ZERO2, ZERO3]
|
||||
if is_torch_bf16_gpu_available():
|
||||
dtypes = [FP16, BF16]
|
||||
@@ -153,6 +161,8 @@ def parameterized_custom_name_func(func, param_num, param):
|
||||
# Cartesian-product of zero stages with models to test
|
||||
params = list(itertools.product(stages, dtypes))
|
||||
|
||||
params_with_optims_and_schedulers = list(itertools.product(stages, dtypes, optims, schedulers))
|
||||
|
||||
|
||||
@require_deepspeed
|
||||
@require_torch_gpu
|
||||
@@ -640,10 +650,16 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
|
||||
"Can't find a valid checkpoint at" in str(context.exception), f"got exception: {context.exception}"
|
||||
)
|
||||
|
||||
@parameterized.expand(params, name_func=parameterized_custom_name_func)
|
||||
def test_can_resume_training_normal(self, stage, dtype):
|
||||
@parameterized.expand(params_with_optims_and_schedulers, name_func=parameterized_custom_name_func)
|
||||
def test_can_resume_training_normal(self, stage, dtype, optim, scheduler):
|
||||
# adapted from TrainerIntegrationTest.test_can_resume_training
|
||||
# test normal resume for each stage separately, error-handling is tested in a different test
|
||||
|
||||
# ToDo: Currently, hf_optim + hf_scheduler resumes with the correct states and
|
||||
# also has same losses for few steps but then slowly diverges. Need to figure it out.
|
||||
if optim == HF_OPTIM and scheduler == HF_SCHEDULER:
|
||||
return
|
||||
|
||||
output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
|
||||
ds_config_dict = self.get_config_dict(stage)
|
||||
if dtype == FP16:
|
||||
@@ -652,6 +668,12 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
|
||||
if stage == ZERO3:
|
||||
ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True
|
||||
|
||||
if optim == HF_OPTIM:
|
||||
del ds_config_dict["optimizer"]
|
||||
|
||||
if scheduler == HF_SCHEDULER:
|
||||
del ds_config_dict["scheduler"]
|
||||
|
||||
kwargs = {
|
||||
"output_dir": output_dir,
|
||||
"train_len": 128,
|
||||
|
||||
Reference in New Issue
Block a user