From 9db31ea58579cf441bc0cf978ecf917a289fdc39 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Mon, 7 Apr 2025 11:36:44 +0200
Subject: [PATCH] Fix deepspeed with quantization (#37324)

* Update modeling_utils.py

* Update modeling_utils.py
---
 src/transformers/modeling_utils.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 6a3286cbc9..67266d558d 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3719,19 +3719,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
     @classmethod
     def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
-        # With deepspeed, we cannot initialize the model on meta device
         if is_deepspeed_zero3_enabled():
             init_contexts = [no_init_weights()]
+            # We cannot initialize the model on meta device with deepspeed when not quantized
             if not is_quantized and not _is_ds_init_called:
                 logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
-                init_contexts.extend(
-                    [
-                        deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
-                        set_zero3_state(),
-                    ]
-                )
+                init_contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])
             elif is_quantized:
-                init_contexts.append(set_quantized_state())
+                init_contexts.extend([init_empty_weights(), set_quantized_state()])
         else:
             init_contexts = [no_init_weights(), init_empty_weights()]
 
@@ -4800,7 +4795,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 continue
 
             map_location = "cpu"
-            if shard_file.endswith(".safetensors") and not is_hqq_or_bnb and not is_deepspeed_zero3_enabled():
+            if (
+                shard_file.endswith(".safetensors")
+                and not is_hqq_or_bnb
+                and not (is_deepspeed_zero3_enabled() and not is_quantized)
+            ):
                 map_location = "meta"
             elif (
                 device_map is not None
@@ -4822,7 +4821,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             # Fix the key names
             state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
 
-            if is_deepspeed_zero3_enabled():
+            if is_deepspeed_zero3_enabled() and not is_quantized:
                 error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
             # Skip it with fsdp on ranks other than 0
             elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):