Fix loading zero3 weights (#36455)

* Check if fixes

* Fix zero3 loading

* Quality

* Fix marc nit

* Add fast tests

* Migrate to integrations.deepspeed rather than modeling_utils

* Style
This commit is contained in:
Zach Mueller
2025-03-03 09:05:58 -05:00
committed by GitHub
parent dcbdf7e962
commit 4d8259d245
3 changed files with 105 additions and 3 deletions

View File

@@ -50,6 +50,7 @@ from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig, GenerationMixin
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .integrations.deepspeed import _load_state_dict_into_zero3_model
from .integrations.flash_attention import flash_attention_forward
from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward
@@ -4918,7 +4919,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
mismatched_names = [name for name, _, _ in mismatched_keys]
fixed_state_dict = {k: v for k, v in state_dict.items() if k not in mismatched_names}
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(fixed_state_dict)
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
if is_deepspeed_zero3_enabled():
error_msgs += _load_state_dict_into_zero3_model(
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
)
else:
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
else:
# This should always be a list but, just to be sure.
if not isinstance(resolved_archive_file, list):
@@ -5009,7 +5016,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model_to_load, state_dict, start_prefix
)
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
if is_deepspeed_zero3_enabled():
error_msgs += _load_state_dict_into_zero3_model(
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
)
else:
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
# force memory release
del state_dict
gc.collect()