@@ -833,7 +833,7 @@ def _load_state_dict_into_meta_model(
|
|||||||
bin_state_dict = None
|
bin_state_dict = None
|
||||||
if shard_file.endswith(".safetensors"):
|
if shard_file.endswith(".safetensors"):
|
||||||
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
|
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
|
||||||
else:
|
elif shard_file.endswith(".bin"):
|
||||||
map_location = "cpu"
|
map_location = "cpu"
|
||||||
if (
|
if (
|
||||||
device_map is not None
|
device_map is not None
|
||||||
@@ -848,6 +848,13 @@ def _load_state_dict_into_meta_model(
|
|||||||
|
|
||||||
is_quantized = hf_quantizer is not None
|
is_quantized = hf_quantizer is not None
|
||||||
|
|
||||||
|
# get full state dict
|
||||||
|
if is_quantized:
|
||||||
|
if shard_file.endswith(".safetensors"):
|
||||||
|
full_state_dict = load_state_dict(shard_file, map_location="cpu")
|
||||||
|
elif shard_file.endswith(".bin"):
|
||||||
|
full_state_dict = bin_state_dict
|
||||||
|
|
||||||
for serialized_param_name, empty_param in state_dict.items():
|
for serialized_param_name, empty_param in state_dict.items():
|
||||||
# serialized_param_name is the raw, serialized name
|
# serialized_param_name is the raw, serialized name
|
||||||
# fixed_param_name is the model's equivalent
|
# fixed_param_name is the model's equivalent
|
||||||
@@ -912,7 +919,7 @@ def _load_state_dict_into_meta_model(
|
|||||||
model,
|
model,
|
||||||
param,
|
param,
|
||||||
fixed_param_name,
|
fixed_param_name,
|
||||||
state_dict,
|
full_state_dict,
|
||||||
param_device=param_device,
|
param_device=param_device,
|
||||||
device_map=device_map,
|
device_map=device_map,
|
||||||
)
|
)
|
||||||
@@ -928,7 +935,7 @@ def _load_state_dict_into_meta_model(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hf_quantizer.create_quantized_param(
|
hf_quantizer.create_quantized_param(
|
||||||
model, param, fixed_param_name, param_device, state_dict, unexpected_keys
|
model, param, fixed_param_name, param_device, full_state_dict, unexpected_keys
|
||||||
)
|
)
|
||||||
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
|
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
|
||||||
# and then cast it to CPU to avoid excessive memory usage on each GPU
|
# and then cast it to CPU to avoid excessive memory usage on each GPU
|
||||||
|
|||||||
Reference in New Issue
Block a user