Fix deepspeed loading (part 2) (#37306)
* fix * Update modeling_utils.py * Update modeling_utils.py * oups remove print
This commit is contained in:
@@ -3723,18 +3723,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
|
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
|
||||||
if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
|
# With deepspeed, we cannot initialize the model on meta device
|
||||||
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
if is_deepspeed_zero3_enabled():
|
||||||
init_contexts = [
|
init_contexts = [no_init_weights()]
|
||||||
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
|
if not is_quantized and not _is_ds_init_called:
|
||||||
set_zero3_state(),
|
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
||||||
no_init_weights(),
|
init_contexts.extend(
|
||||||
]
|
[
|
||||||
|
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
|
||||||
|
set_zero3_state(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
elif is_quantized:
|
||||||
|
init_contexts.append(set_quantized_state())
|
||||||
else:
|
else:
|
||||||
init_contexts = [no_init_weights(), init_empty_weights()]
|
init_contexts = [no_init_weights(), init_empty_weights()]
|
||||||
|
|
||||||
if is_deepspeed_zero3_enabled() and is_quantized:
|
|
||||||
init_contexts.append(set_quantized_state())
|
|
||||||
return init_contexts
|
return init_contexts
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user