From 7652804d237fb8768f0f0b8129a05e4f0576114b Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Wed, 12 Mar 2025 11:40:46 +0100 Subject: [PATCH] Fix bnb regression due to empty state dict (#36663) fix --- src/transformers/modeling_utils.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c4cf20c060..8462fb84b1 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -833,7 +833,7 @@ def _load_state_dict_into_meta_model( bin_state_dict = None if shard_file.endswith(".safetensors"): file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) - else: + elif shard_file.endswith(".bin"): map_location = "cpu" if ( device_map is not None @@ -848,6 +848,13 @@ def _load_state_dict_into_meta_model( is_quantized = hf_quantizer is not None + # get full state dict + if is_quantized: + if shard_file.endswith(".safetensors"): + full_state_dict = load_state_dict(shard_file, map_location="cpu") + elif shard_file.endswith(".bin"): + full_state_dict = bin_state_dict + for serialized_param_name, empty_param in state_dict.items(): # serialized_param_name is the raw, serialized name # fixed_param_name is the model's equivalent @@ -912,7 +919,7 @@ def _load_state_dict_into_meta_model( model, param, fixed_param_name, - state_dict, + full_state_dict, param_device=param_device, device_map=device_map, ) @@ -928,7 +935,7 @@ def _load_state_dict_into_meta_model( ) else: hf_quantizer.create_quantized_param( - model, param, fixed_param_name, param_device, state_dict, unexpected_keys + model, param, fixed_param_name, param_device, full_state_dict, unexpected_keys ) # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU # and then cast it to CPU to avoid excessive memory usage on each GPU