From e94571580ba9d0c71feaf489520383c83e167d40 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Sat, 5 Apr 2025 20:41:42 +0200
Subject: [PATCH] Fix deepspeed loading (part 2) (#37306)

* fix

* Update modeling_utils.py

* Update modeling_utils.py

* oups remove print
---
 src/transformers/modeling_utils.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index cbeb857906..2e0a389d74 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3723,18 +3723,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
     @classmethod
     def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
-        if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
-            logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
-            init_contexts = [
-                deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
-                set_zero3_state(),
-                no_init_weights(),
-            ]
+        # With deepspeed, we cannot initialize the model on meta device
+        if is_deepspeed_zero3_enabled():
+            init_contexts = [no_init_weights()]
+            if not is_quantized and not _is_ds_init_called:
+                logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
+                init_contexts.extend(
+                    [
+                        deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
+                        set_zero3_state(),
+                    ]
+                )
+            elif is_quantized:
+                init_contexts.append(set_quantized_state())
         else:
             init_contexts = [no_init_weights(), init_empty_weights()]
 
-        if is_deepspeed_zero3_enabled() and is_quantized:
-            init_contexts.append(set_quantized_state())
         return init_contexts
 
     @classmethod