Fix loading zero3 weights (#36455)
* Check if fixes * Fix zero3 loading * Quality * Fix marc nit * Add fast tests * Migrate to integrations.deepspeed rather than modeling_utils * Style
This commit is contained in:
@@ -27,6 +27,7 @@ from ..utils import is_accelerate_available, is_torch_available, is_torch_mlu_av
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -305,6 +306,57 @@ def deepspeed_config():
|
||||
return None
|
||||
|
||||
|
||||
def _load_state_dict_into_zero3_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
|
||||
"""
|
||||
Loads state dict into a model specifically for Zero3, since DeepSpeed does not support the `transformers`
|
||||
tensor parallelism API.
|
||||
|
||||
Nearly identical code to PyTorch's `_load_from_state_dict`
|
||||
"""
|
||||
# copy state_dict so `_load_state_dict_into_zero3_model` can modify it
|
||||
metadata = getattr(state_dict, "_metadata", None)
|
||||
state_dict = state_dict.copy()
|
||||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
|
||||
error_msgs = []
|
||||
|
||||
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
||||
# so we need to apply the function recursively.
|
||||
def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
|
||||
|
||||
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
|
||||
# Parameters of module and children will start with prefix. We can exit early if there are none in this
|
||||
# state_dict
|
||||
if is_deepspeed_zero3_enabled() and len([key for key in state_dict if key.startswith(prefix)]) > 0:
|
||||
import deepspeed
|
||||
|
||||
# In sharded models, each shard has only part of the full state_dict, so only gather
|
||||
# parameters that are in the current state_dict.
|
||||
named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
|
||||
params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
|
||||
if len(params_to_gather) > 0:
|
||||
# because zero3 puts placeholders in model params, this context
|
||||
# manager gathers (unpartitions) the params of the current layer, then loads from
|
||||
# the state dict and then re-partitions them again
|
||||
with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
|
||||
if torch.distributed.get_rank() == 0:
|
||||
module._load_from_state_dict(*args)
|
||||
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
|
||||
|
||||
load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
|
||||
# Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
|
||||
# it's safe to delete it.
|
||||
del state_dict
|
||||
|
||||
return error_msgs
|
||||
|
||||
|
||||
def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
|
||||
"""
|
||||
A convenience wrapper that deals with optimizer and lr scheduler configuration.
|
||||
|
||||
@@ -50,6 +50,7 @@ from .configuration_utils import PretrainedConfig
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .generation import CompileConfig, GenerationConfig, GenerationMixin
|
||||
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
|
||||
from .integrations.deepspeed import _load_state_dict_into_zero3_model
|
||||
from .integrations.flash_attention import flash_attention_forward
|
||||
from .integrations.flex_attention import flex_attention_forward
|
||||
from .integrations.sdpa_attention import sdpa_attention_forward
|
||||
@@ -4918,7 +4919,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
mismatched_names = [name for name, _, _ in mismatched_keys]
|
||||
fixed_state_dict = {k: v for k, v in state_dict.items() if k not in mismatched_names}
|
||||
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(fixed_state_dict)
|
||||
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
error_msgs += _load_state_dict_into_zero3_model(
|
||||
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
|
||||
)
|
||||
else:
|
||||
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
|
||||
else:
|
||||
# This should always be a list but, just to be sure.
|
||||
if not isinstance(resolved_archive_file, list):
|
||||
@@ -5009,7 +5016,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
model_to_load, state_dict, start_prefix
|
||||
)
|
||||
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
|
||||
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
|
||||
if is_deepspeed_zero3_enabled():
|
||||
error_msgs += _load_state_dict_into_zero3_model(
|
||||
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
|
||||
)
|
||||
else:
|
||||
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
|
||||
# force memory release
|
||||
del state_dict
|
||||
gc.collect()
|
||||
|
||||
@@ -170,7 +170,6 @@ params_with_optims_and_schedulers = list(itertools.product(stages, dtypes, optim
|
||||
|
||||
|
||||
@require_deepspeed
|
||||
@require_torch_accelerator
|
||||
class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
"""
|
||||
Testing non-Trainer DeepSpeed integration
|
||||
@@ -194,6 +193,42 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
# reset the ds config global so that tests state doesn't leak
|
||||
unset_hf_deepspeed_config()
|
||||
|
||||
def test_init_zero3(self):
|
||||
# test that zero.Init() works correctly
|
||||
ds_config = {
|
||||
"train_batch_size": 1,
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
},
|
||||
}
|
||||
|
||||
dschf = HfDeepSpeedConfig(ds_config)
|
||||
|
||||
self.assertTrue(dschf.is_zero3())
|
||||
self.assertTrue(is_deepspeed_zero3_enabled())
|
||||
|
||||
with LoggingLevel(logging.INFO):
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
logger = logging.get_logger("transformers.modeling_utils")
|
||||
with CaptureLogger(logger) as cl:
|
||||
AutoModel.from_pretrained(T5_TINY)
|
||||
self.assertIn("Detected DeepSpeed ZeRO-3", cl.out)
|
||||
|
||||
# now remove zero optimization
|
||||
del ds_config["zero_optimization"]
|
||||
dschf = HfDeepSpeedConfig(ds_config)
|
||||
|
||||
self.assertFalse(dschf.is_zero3())
|
||||
self.assertFalse(is_deepspeed_zero3_enabled())
|
||||
|
||||
with LoggingLevel(logging.INFO):
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
logger = logging.get_logger("transformers.modeling_utils")
|
||||
with CaptureLogger(logger) as cl:
|
||||
AutoModel.from_pretrained(T5_TINY)
|
||||
self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_init_zero3_fp16(self):
|
||||
# test that zero.Init() works correctly under zero3/fp16
|
||||
ds_config = {
|
||||
@@ -201,6 +236,9 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
},
|
||||
}
|
||||
|
||||
dschf = HfDeepSpeedConfig(ds_config)
|
||||
|
||||
Reference in New Issue
Block a user