Fix loading zero3 weights (#36455)
* Check if fixes
* Fix zero3 loading
* Quality
* Fix marc nit
* Add fast tests
* Migrate to integrations.deepspeed rather than modeling_utils
* Style
This commit is contained in:
@@ -27,6 +27,7 @@ from ..utils import is_accelerate_available, is_torch_available, is_torch_mlu_av
|
|||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -305,6 +306,57 @@ def deepspeed_config():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _load_state_dict_into_zero3_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
    """
    Loads state dict into a model specifically for Zero3, since DeepSpeed does not support the `transformers`
    tensor parallelism API.

    Nearly identical code to PyTorch's `_load_from_state_dict`

    Args:
        model_to_load (`nn.Module`):
            Root module to load into; the function walks `module._modules` recursively from here.
        state_dict:
            Mapping of parameter names to values. It must support `.copy()`; an optional `_metadata` attribute
            is carried over onto the copy. The caller's mapping itself is not modified.
        start_prefix (`str`):
            Key prefix used for the root module. Children are visited with `prefix + name + "."`.
        assign_to_params_buffers (`bool`, *optional*, defaults to `False`):
            Stored under the `"assign_to_params_buffers"` key of the per-module metadata that is handed to
            `module._load_from_state_dict`.

    Returns:
        `list`: the `error_msgs` list that is passed to every `module._load_from_state_dict` call.
    """
    # copy state_dict so `_load_state_dict_into_zero3_model` can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
        # Copy the per-module metadata entry before adding our flag: `metadata` is still the object attached
        # to the caller's state dict, so writing into the entry directly would leak the flag back to the caller.
        local_metadata = {} if metadata is None else dict(metadata.get(prefix[:-1], {}))
        local_metadata["assign_to_params_buffers"] = assign_to_params_buffers

        # Positional arguments of `_load_from_state_dict`; presumably
        # (state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - see the
        # PyTorch source to confirm. The two fresh lists are discarded, only `error_msgs` is kept.
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        # Parameters of module and children will start with prefix. We can exit early if there are none in this
        # state_dict
        if is_deepspeed_zero3_enabled() and any(key.startswith(prefix) for key in state_dict):
            import deepspeed

            # In sharded models, each shard has only part of the full state_dict, so only gather
            # parameters that are in the current state_dict.
            # Iterate over this module's own (few) parameters and test membership in the state dict, rather
            # than scanning every state-dict key once per sub-module.
            named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
            params_to_gather = [param for name, param in named_parameters.items() if name in state_dict]
            if params_to_gather:
                # because zero3 puts placeholders in model params, this context
                # manager gathers (unpartitions) the params of the current layer, then loads from
                # the state dict and then re-partitions them again
                with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
                    # Only rank 0 writes the values, matching `modifier_rank=0` above.
                    if torch.distributed.get_rank() == 0:
                        module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".", assign_to_params_buffers)

    load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
    # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
    # it's safe to delete it.
    del state_dict

    return error_msgs
|
||||||
|
|
||||||
|
|
||||||
def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
|
def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
|
||||||
"""
|
"""
|
||||||
A convenience wrapper that deals with optimizer and lr scheduler configuration.
|
A convenience wrapper that deals with optimizer and lr scheduler configuration.
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ from .configuration_utils import PretrainedConfig
|
|||||||
from .dynamic_module_utils import custom_object_save
|
from .dynamic_module_utils import custom_object_save
|
||||||
from .generation import CompileConfig, GenerationConfig, GenerationMixin
|
from .generation import CompileConfig, GenerationConfig, GenerationMixin
|
||||||
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
|
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
|
||||||
|
from .integrations.deepspeed import _load_state_dict_into_zero3_model
|
||||||
from .integrations.flash_attention import flash_attention_forward
|
from .integrations.flash_attention import flash_attention_forward
|
||||||
from .integrations.flex_attention import flex_attention_forward
|
from .integrations.flex_attention import flex_attention_forward
|
||||||
from .integrations.sdpa_attention import sdpa_attention_forward
|
from .integrations.sdpa_attention import sdpa_attention_forward
|
||||||
@@ -4918,7 +4919,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
mismatched_names = [name for name, _, _ in mismatched_keys]
|
mismatched_names = [name for name, _, _ in mismatched_keys]
|
||||||
fixed_state_dict = {k: v for k, v in state_dict.items() if k not in mismatched_names}
|
fixed_state_dict = {k: v for k, v in state_dict.items() if k not in mismatched_names}
|
||||||
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(fixed_state_dict)
|
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(fixed_state_dict)
|
||||||
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
|
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
error_msgs += _load_state_dict_into_zero3_model(
|
||||||
|
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
|
||||||
else:
|
else:
|
||||||
# This should always be a list but, just to be sure.
|
# This should always be a list but, just to be sure.
|
||||||
if not isinstance(resolved_archive_file, list):
|
if not isinstance(resolved_archive_file, list):
|
||||||
@@ -5009,7 +5016,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
model_to_load, state_dict, start_prefix
|
model_to_load, state_dict, start_prefix
|
||||||
)
|
)
|
||||||
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
|
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
|
||||||
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
|
if is_deepspeed_zero3_enabled():
|
||||||
|
error_msgs += _load_state_dict_into_zero3_model(
|
||||||
|
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
|
||||||
# force memory release
|
# force memory release
|
||||||
del state_dict
|
del state_dict
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|||||||
@@ -170,7 +170,6 @@ params_with_optims_and_schedulers = list(itertools.product(stages, dtypes, optim
|
|||||||
|
|
||||||
|
|
||||||
@require_deepspeed
|
@require_deepspeed
|
||||||
@require_torch_accelerator
|
|
||||||
class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||||
"""
|
"""
|
||||||
Testing non-Trainer DeepSpeed integration
|
Testing non-Trainer DeepSpeed integration
|
||||||
@@ -194,6 +193,42 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
# reset the ds config global so that tests state doesn't leak
|
# reset the ds config global so that tests state doesn't leak
|
||||||
unset_hf_deepspeed_config()
|
unset_hf_deepspeed_config()
|
||||||
|
|
||||||
|
def test_init_zero3(self):
    """
    Check that `from_pretrained` logs the ZeRO-3 detection message while a stage-3 DeepSpeed config is
    registered, and stops logging it once the `zero_optimization` section is removed from the config.
    """
    zero3_marker = "Detected DeepSpeed ZeRO-3"

    def load_tiny_model_and_capture_log():
        # Load the tiny T5 checkpoint under the mocked single-GPU env at INFO level and return whatever
        # the `transformers.modeling_utils` logger emitted meanwhile.
        with LoggingLevel(logging.INFO), mockenv_context(**self.dist_env_1_gpu):
            modeling_logger = logging.get_logger("transformers.modeling_utils")
            with CaptureLogger(modeling_logger) as captured:
                AutoModel.from_pretrained(T5_TINY)
        # read the captured text only after the capture context has been left
        return captured.out

    # test that zero.Init() works correctly
    ds_config = {"train_batch_size": 1, "zero_optimization": {"stage": 3}}

    # keep the config object referenced by a local for as long as it is needed
    hf_ds_config = HfDeepSpeedConfig(ds_config)
    self.assertTrue(hf_ds_config.is_zero3())
    self.assertTrue(is_deepspeed_zero3_enabled())
    self.assertIn(zero3_marker, load_tiny_model_and_capture_log())

    # now remove zero optimization
    ds_config.pop("zero_optimization")
    hf_ds_config = HfDeepSpeedConfig(ds_config)
    self.assertFalse(hf_ds_config.is_zero3())
    self.assertFalse(is_deepspeed_zero3_enabled())
    self.assertNotIn(zero3_marker, load_tiny_model_and_capture_log())
|
||||||
|
|
||||||
|
@require_torch_accelerator
|
||||||
def test_init_zero3_fp16(self):
|
def test_init_zero3_fp16(self):
|
||||||
# test that zero.Init() works correctly under zero3/fp16
|
# test that zero.Init() works correctly under zero3/fp16
|
||||||
ds_config = {
|
ds_config = {
|
||||||
@@ -201,6 +236,9 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
"zero_optimization": {
|
"zero_optimization": {
|
||||||
"stage": 3,
|
"stage": 3,
|
||||||
},
|
},
|
||||||
|
"fp16": {
|
||||||
|
"enabled": True,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
dschf = HfDeepSpeedConfig(ds_config)
|
dschf = HfDeepSpeedConfig(ds_config)
|
||||||
|
|||||||
Reference in New Issue
Block a user