Fix bnb regression due to empty state dict (#36663)

fix
This commit is contained in:
Marc Sun
2025-03-12 11:40:46 +01:00
committed by GitHub
parent 994cad2790
commit 7652804d23

View File

@@ -833,7 +833,7 @@ def _load_state_dict_into_meta_model(
bin_state_dict = None bin_state_dict = None
if shard_file.endswith(".safetensors"): if shard_file.endswith(".safetensors"):
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
else: elif shard_file.endswith(".bin"):
map_location = "cpu" map_location = "cpu"
if ( if (
device_map is not None device_map is not None
@@ -848,6 +848,13 @@ def _load_state_dict_into_meta_model(
is_quantized = hf_quantizer is not None is_quantized = hf_quantizer is not None
# get full state dict
if is_quantized:
if shard_file.endswith(".safetensors"):
full_state_dict = load_state_dict(shard_file, map_location="cpu")
elif shard_file.endswith(".bin"):
full_state_dict = bin_state_dict
for serialized_param_name, empty_param in state_dict.items(): for serialized_param_name, empty_param in state_dict.items():
# serialized_param_name is the raw, serialized name # serialized_param_name is the raw, serialized name
# fixed_param_name is the model's equivalent # fixed_param_name is the model's equivalent
@@ -912,7 +919,7 @@ def _load_state_dict_into_meta_model(
model, model,
param, param,
fixed_param_name, fixed_param_name,
state_dict, full_state_dict,
param_device=param_device, param_device=param_device,
device_map=device_map, device_map=device_map,
) )
@@ -928,7 +935,7 @@ def _load_state_dict_into_meta_model(
) )
else: else:
hf_quantizer.create_quantized_param( hf_quantizer.create_quantized_param(
model, param, fixed_param_name, param_device, state_dict, unexpected_keys model, param, fixed_param_name, param_device, full_state_dict, unexpected_keys
) )
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
# and then cast it to CPU to avoid excessive memory usage on each GPU # and then cast it to CPU to avoid excessive memory usage on each GPU