deepspeed resume from ckpt fixes and adding support for deepspeed optimizer and HF scheduler (#25863)
* Add support for deepspeed optimizer and HF scheduler * fix bug * fix the import * fix issue with deepspeed scheduler saving for hf optim + hf scheduler scenario * fix loading of hf scheduler when loading deepspeed checkpoint * fix import of `DeepSpeedSchedulerWrapper` * add tests * add the comment and skip the failing tests * address comment
This commit is contained in:
committed by
GitHub
parent
1110b565d6
commit
6bc517ccd4
@@ -26,6 +26,8 @@ from ..utils import is_accelerate_available, is_torch_available, logging
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from ..optimization import get_scheduler
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -274,7 +276,7 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
|
|||||||
# 1. DS scheduler + DS optimizer: Yes
|
# 1. DS scheduler + DS optimizer: Yes
|
||||||
# 2. HF scheduler + HF optimizer: Mostly*
|
# 2. HF scheduler + HF optimizer: Mostly*
|
||||||
# 3. DS scheduler + HF optimizer: Mostly*
|
# 3. DS scheduler + HF optimizer: Mostly*
|
||||||
# 4. HF scheduler + DS optimizer: No
|
# 4. HF scheduler + DS optimizer: Yes
|
||||||
#
|
#
|
||||||
# Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB)
|
# Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB)
|
||||||
|
|
||||||
@@ -304,11 +306,18 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
|
|||||||
lr_scheduler = DummyScheduler(optimizer)
|
lr_scheduler = DummyScheduler(optimizer)
|
||||||
else:
|
else:
|
||||||
if isinstance(optimizer, DummyOptim):
|
if isinstance(optimizer, DummyOptim):
|
||||||
raise ValueError(
|
|
||||||
"Found `optimizer` configured in the DeepSpeed config, but no `scheduler`. "
|
def _lr_scheduler_callable(optimizer):
|
||||||
"Please configure a scheduler in the DeepSpeed config."
|
return get_scheduler(
|
||||||
)
|
trainer.args.lr_scheduler_type,
|
||||||
lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
|
optimizer=optimizer,
|
||||||
|
num_warmup_steps=trainer.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable)
|
||||||
|
else:
|
||||||
|
lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
|
||||||
|
|
||||||
return optimizer, lr_scheduler
|
return optimizer, lr_scheduler
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ from .data.data_collator import DataCollator, DataCollatorWithPadding, default_d
|
|||||||
from .debug_utils import DebugOption, DebugUnderflowOverflow
|
from .debug_utils import DebugOption, DebugUnderflowOverflow
|
||||||
from .dependency_versions_check import dep_version_check
|
from .dependency_versions_check import dep_version_check
|
||||||
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
|
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
|
||||||
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint
|
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
|
||||||
from .modelcard import TrainingSummary
|
from .modelcard import TrainingSummary
|
||||||
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
|
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
|
||||||
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
|
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
|
||||||
@@ -212,6 +212,9 @@ if is_accelerate_available():
|
|||||||
save_fsdp_optimizer,
|
save_fsdp_optimizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_deepspeed_available():
|
||||||
|
from accelerate.utils import DeepSpeedSchedulerWrapper
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import optuna
|
import optuna
|
||||||
@@ -2362,7 +2365,14 @@ class Trainer:
|
|||||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||||
|
|
||||||
# Save SCHEDULER & SCALER
|
# Save SCHEDULER & SCALER
|
||||||
if self.args.should_save and not self.is_deepspeed_enabled and not is_torch_tpu_available():
|
is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
|
||||||
|
self.lr_scheduler, DeepSpeedSchedulerWrapper
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
self.args.should_save
|
||||||
|
and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
|
||||||
|
and not is_torch_tpu_available()
|
||||||
|
):
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||||
reissue_pt_warnings(caught_warnings)
|
reissue_pt_warnings(caught_warnings)
|
||||||
@@ -2428,6 +2438,10 @@ class Trainer:
|
|||||||
|
|
||||||
if self.is_deepspeed_enabled:
|
if self.is_deepspeed_enabled:
|
||||||
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
|
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
|
||||||
|
if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
|
||||||
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
|
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
|
||||||
|
reissue_pt_warnings(caught_warnings)
|
||||||
return
|
return
|
||||||
|
|
||||||
checkpoint_file_exists = (
|
checkpoint_file_exists = (
|
||||||
|
|||||||
@@ -136,6 +136,14 @@ ZERO3 = "zero3"
|
|||||||
FP16 = "fp16"
|
FP16 = "fp16"
|
||||||
BF16 = "bf16"
|
BF16 = "bf16"
|
||||||
|
|
||||||
|
HF_OPTIM = "hf_optim"
|
||||||
|
HF_SCHEDULER = "hf_scheduler"
|
||||||
|
DS_OPTIM = "ds_optim"
|
||||||
|
DS_SCHEDULER = "ds_scheduler"
|
||||||
|
|
||||||
|
optims = [HF_OPTIM, DS_OPTIM]
|
||||||
|
schedulers = [HF_SCHEDULER, DS_SCHEDULER]
|
||||||
|
|
||||||
stages = [ZERO2, ZERO3]
|
stages = [ZERO2, ZERO3]
|
||||||
if is_torch_bf16_gpu_available():
|
if is_torch_bf16_gpu_available():
|
||||||
dtypes = [FP16, BF16]
|
dtypes = [FP16, BF16]
|
||||||
@@ -153,6 +161,8 @@ def parameterized_custom_name_func(func, param_num, param):
|
|||||||
# Cartesian-product of zero stages with models to test
|
# Cartesian-product of zero stages with models to test
|
||||||
params = list(itertools.product(stages, dtypes))
|
params = list(itertools.product(stages, dtypes))
|
||||||
|
|
||||||
|
params_with_optims_and_schedulers = list(itertools.product(stages, dtypes, optims, schedulers))
|
||||||
|
|
||||||
|
|
||||||
@require_deepspeed
|
@require_deepspeed
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@@ -640,10 +650,16 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
|
|||||||
"Can't find a valid checkpoint at" in str(context.exception), f"got exception: {context.exception}"
|
"Can't find a valid checkpoint at" in str(context.exception), f"got exception: {context.exception}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@parameterized.expand(params, name_func=parameterized_custom_name_func)
|
@parameterized.expand(params_with_optims_and_schedulers, name_func=parameterized_custom_name_func)
|
||||||
def test_can_resume_training_normal(self, stage, dtype):
|
def test_can_resume_training_normal(self, stage, dtype, optim, scheduler):
|
||||||
# adapted from TrainerIntegrationTest.test_can_resume_training
|
# adapted from TrainerIntegrationTest.test_can_resume_training
|
||||||
# test normal resume for each stage separately, error-handling is tested in a different test
|
# test normal resume for each stage separately, error-handling is tested in a different test
|
||||||
|
|
||||||
|
# ToDo: Currently, hf_optim + hf_scheduler resumes with the correct states and
|
||||||
|
# also has same losses for few steps but then slowly diverges. Need to figure it out.
|
||||||
|
if optim == HF_OPTIM and scheduler == HF_SCHEDULER:
|
||||||
|
return
|
||||||
|
|
||||||
output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
|
output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
|
||||||
ds_config_dict = self.get_config_dict(stage)
|
ds_config_dict = self.get_config_dict(stage)
|
||||||
if dtype == FP16:
|
if dtype == FP16:
|
||||||
@@ -652,6 +668,12 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
|
|||||||
if stage == ZERO3:
|
if stage == ZERO3:
|
||||||
ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True
|
ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True
|
||||||
|
|
||||||
|
if optim == HF_OPTIM:
|
||||||
|
del ds_config_dict["optimizer"]
|
||||||
|
|
||||||
|
if scheduler == HF_SCHEDULER:
|
||||||
|
del ds_config_dict["scheduler"]
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"output_dir": output_dir,
|
"output_dir": output_dir,
|
||||||
"train_len": 128,
|
"train_len": 128,
|
||||||
|
|||||||
Reference in New Issue
Block a user