Fix loading zero3 weights (#36455)
* Check if fixes
* Fix zero3 loading
* Quality
* Fix marc nit
* Add fast tests
* Migrate to integrations.deepspeed rather than modeling_utils
* Style
This commit is contained in:
@@ -50,6 +50,7 @@ from .configuration_utils import PretrainedConfig
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .generation import CompileConfig, GenerationConfig, GenerationMixin
|
||||
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
|
||||
+from .integrations.deepspeed import _load_state_dict_into_zero3_model
|
||||
from .integrations.flash_attention import flash_attention_forward
|
||||
from .integrations.flex_attention import flex_attention_forward
|
||||
from .integrations.sdpa_attention import sdpa_attention_forward
|
||||
@@ -4918,7 +4919,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
mismatched_names = [name for name, _, _ in mismatched_keys]
|
||||
fixed_state_dict = {k: v for k, v in state_dict.items() if k not in mismatched_names}
|
||||
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(fixed_state_dict)
|
||||
-model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
|
||||
|
||||
+if is_deepspeed_zero3_enabled():
|
||||
+error_msgs += _load_state_dict_into_zero3_model(
|
||||
+model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
|
||||
+)
|
||||
+else:
|
||||
+model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
|
||||
else:
|
||||
# This should always be a list but, just to be sure.
|
||||
if not isinstance(resolved_archive_file, list):
|
||||
@@ -5009,7 +5016,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
model_to_load, state_dict, start_prefix
|
||||
)
|
||||
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
|
||||
-model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
|
||||
+if is_deepspeed_zero3_enabled():
|
||||
+error_msgs += _load_state_dict_into_zero3_model(
|
||||
+model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
|
||||
+)
|
||||
+else:
|
||||
+model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
|
||||
# force memory release
|
||||
del state_dict
|
||||
gc.collect()
|
||||
|
||||
Reference in New Issue
Block a user