Fix loading zero3 weights (#36455)

* Check if fixes

* Fix zero3 loading

* Quality

* Fix marc nit

* Add fast tests

* Migrate to integrations.deepspeed rather than modeling_utils

* Style
This commit is contained in:
Zach Mueller
2025-03-03 09:05:58 -05:00
committed by GitHub
parent dcbdf7e962
commit 4d8259d245
3 changed files with 105 additions and 3 deletions

View File

@@ -27,6 +27,7 @@ from ..utils import is_accelerate_available, is_torch_available, is_torch_mlu_av
if is_torch_available(): if is_torch_available():
import torch import torch
from torch import nn
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -305,6 +306,57 @@ def deepspeed_config():
return None return None
def _load_state_dict_into_zero3_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
"""
Loads state dict into a model specifically for Zero3, since DeepSpeed does not support the `transformers`
tensor parallelism API.
Nearly identical code to PyTorch's `_load_from_state_dict`
"""
# copy state_dict so `_load_state_dict_into_zero3_model` can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
error_msgs = []
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict
if is_deepspeed_zero3_enabled() and len([key for key in state_dict if key.startswith(prefix)]) > 0:
import deepspeed
# In sharded models, each shard has only part of the full state_dict, so only gather
# parameters that are in the current state_dict.
named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
if len(params_to_gather) > 0:
# because zero3 puts placeholders in model params, this context
# manager gathers (unpartitions) the params of the current layer, then loads from
# the state dict and then re-partitions them again
with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
if torch.distributed.get_rank() == 0:
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
# Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
# it's safe to delete it.
del state_dict
return error_msgs
def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters): def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
""" """
A convenience wrapper that deals with optimizer and lr scheduler configuration. A convenience wrapper that deals with optimizer and lr scheduler configuration.

View File

@@ -50,6 +50,7 @@ from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig, GenerationMixin from .generation import CompileConfig, GenerationConfig, GenerationMixin
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .integrations.deepspeed import _load_state_dict_into_zero3_model
from .integrations.flash_attention import flash_attention_forward from .integrations.flash_attention import flash_attention_forward
from .integrations.flex_attention import flex_attention_forward from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward from .integrations.sdpa_attention import sdpa_attention_forward
@@ -4918,7 +4919,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
mismatched_names = [name for name, _, _ in mismatched_keys] mismatched_names = [name for name, _, _ in mismatched_keys]
fixed_state_dict = {k: v for k, v in state_dict.items() if k not in mismatched_names} fixed_state_dict = {k: v for k, v in state_dict.items() if k not in mismatched_names}
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(fixed_state_dict) fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(fixed_state_dict)
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
if is_deepspeed_zero3_enabled():
error_msgs += _load_state_dict_into_zero3_model(
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
)
else:
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
else: else:
# This should always be a list but, just to be sure. # This should always be a list but, just to be sure.
if not isinstance(resolved_archive_file, list): if not isinstance(resolved_archive_file, list):
@@ -5009,7 +5016,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model_to_load, state_dict, start_prefix model_to_load, state_dict, start_prefix
) )
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict) fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers) if is_deepspeed_zero3_enabled():
error_msgs += _load_state_dict_into_zero3_model(
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
)
else:
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
# force memory release # force memory release
del state_dict del state_dict
gc.collect() gc.collect()

View File

@@ -170,7 +170,6 @@ params_with_optims_and_schedulers = list(itertools.product(stages, dtypes, optim
@require_deepspeed @require_deepspeed
@require_torch_accelerator
class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon): class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
""" """
Testing non-Trainer DeepSpeed integration Testing non-Trainer DeepSpeed integration
@@ -194,6 +193,42 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
# reset the ds config global so that tests state doesn't leak # reset the ds config global so that tests state doesn't leak
unset_hf_deepspeed_config() unset_hf_deepspeed_config()
def test_init_zero3(self):
# test that zero.Init() works correctly
ds_config = {
"train_batch_size": 1,
"zero_optimization": {
"stage": 3,
},
}
dschf = HfDeepSpeedConfig(ds_config)
self.assertTrue(dschf.is_zero3())
self.assertTrue(is_deepspeed_zero3_enabled())
with LoggingLevel(logging.INFO):
with mockenv_context(**self.dist_env_1_gpu):
logger = logging.get_logger("transformers.modeling_utils")
with CaptureLogger(logger) as cl:
AutoModel.from_pretrained(T5_TINY)
self.assertIn("Detected DeepSpeed ZeRO-3", cl.out)
# now remove zero optimization
del ds_config["zero_optimization"]
dschf = HfDeepSpeedConfig(ds_config)
self.assertFalse(dschf.is_zero3())
self.assertFalse(is_deepspeed_zero3_enabled())
with LoggingLevel(logging.INFO):
with mockenv_context(**self.dist_env_1_gpu):
logger = logging.get_logger("transformers.modeling_utils")
with CaptureLogger(logger) as cl:
AutoModel.from_pretrained(T5_TINY)
self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
@require_torch_accelerator
def test_init_zero3_fp16(self): def test_init_zero3_fp16(self):
# test that zero.Init() works correctly under zero3/fp16 # test that zero.Init() works correctly under zero3/fp16
ds_config = { ds_config = {
@@ -201,6 +236,9 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
"zero_optimization": { "zero_optimization": {
"stage": 3, "stage": 3,
}, },
"fp16": {
"enabled": True,
},
} }
dschf = HfDeepSpeedConfig(ds_config) dschf = HfDeepSpeedConfig(ds_config)