Fix deepspeed loading (part 2) (#37306)

* fix

* Update modeling_utils.py

* Update modeling_utils.py

* oups remove print
This commit is contained in:
Cyril Vallez
2025-04-05 20:41:42 +02:00
committed by GitHub
parent 84aa13dd85
commit e94571580b

View File

@@ -3723,18 +3723,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
@classmethod @classmethod
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool): def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called: # With deepspeed, we cannot initialize the model on meta device
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") if is_deepspeed_zero3_enabled():
init_contexts = [ init_contexts = [no_init_weights()]
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), if not is_quantized and not _is_ds_init_called:
set_zero3_state(), logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
no_init_weights(), init_contexts.extend(
] [
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
set_zero3_state(),
]
)
elif is_quantized:
init_contexts.append(set_quantized_state())
else: else:
init_contexts = [no_init_weights(), init_empty_weights()] init_contexts = [no_init_weights(), init_empty_weights()]
if is_deepspeed_zero3_enabled() and is_quantized:
init_contexts.append(set_quantized_state())
return init_contexts return init_contexts
@classmethod @classmethod