Fix deepspeed with quantization (#37324)
* Update modeling_utils.py * Update modeling_utils.py
This commit is contained in:
@@ -3719,19 +3719,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
|
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
|
||||||
# With deepspeed, we cannot initialize the model on meta device
|
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
init_contexts = [no_init_weights()]
|
init_contexts = [no_init_weights()]
|
||||||
|
# We cannot initialize the model on meta device with deepspeed when not quantized
|
||||||
if not is_quantized and not _is_ds_init_called:
|
if not is_quantized and not _is_ds_init_called:
|
||||||
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
||||||
init_contexts.extend(
|
init_contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])
|
||||||
[
|
|
||||||
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
|
|
||||||
set_zero3_state(),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
elif is_quantized:
|
elif is_quantized:
|
||||||
init_contexts.append(set_quantized_state())
|
init_contexts.extend([init_empty_weights(), set_quantized_state()])
|
||||||
else:
|
else:
|
||||||
init_contexts = [no_init_weights(), init_empty_weights()]
|
init_contexts = [no_init_weights(), init_empty_weights()]
|
||||||
|
|
||||||
@@ -4800,7 +4795,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
map_location = "cpu"
|
map_location = "cpu"
|
||||||
if shard_file.endswith(".safetensors") and not is_hqq_or_bnb and not is_deepspeed_zero3_enabled():
|
if (
|
||||||
|
shard_file.endswith(".safetensors")
|
||||||
|
and not is_hqq_or_bnb
|
||||||
|
and not (is_deepspeed_zero3_enabled() and not is_quantized)
|
||||||
|
):
|
||||||
map_location = "meta"
|
map_location = "meta"
|
||||||
elif (
|
elif (
|
||||||
device_map is not None
|
device_map is not None
|
||||||
@@ -4822,7 +4821,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# Fix the key names
|
# Fix the key names
|
||||||
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
|
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
|
||||||
|
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||||
error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
|
error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
|
||||||
# Skip it with fsdp on ranks other than 0
|
# Skip it with fsdp on ranks other than 0
|
||||||
elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
|
elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
|
||||||
|
|||||||
Reference in New Issue
Block a user