Fix meta state dict loading with quantizers (#37136)
Update modeling_utils.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -752,7 +752,11 @@ def _load_state_dict_into_meta_model(
|
||||
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
|
||||
|
||||
is_quantized = hf_quantizer is not None
|
||||
is_meta_state_dict = shard_file.endswith(".safetensors")
|
||||
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
|
||||
QuantizationMethod.HQQ,
|
||||
QuantizationMethod.BITS_AND_BYTES,
|
||||
]
|
||||
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb
|
||||
file_pointer = None
|
||||
if is_meta_state_dict:
|
||||
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
|
||||
@@ -4828,6 +4832,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
|
||||
|
||||
error_msgs = []
|
||||
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
|
||||
QuantizationMethod.HQQ,
|
||||
QuantizationMethod.BITS_AND_BYTES,
|
||||
]
|
||||
# Iterate on all the shards to load the weights
|
||||
for shard_file in checkpoint_files:
|
||||
# Skip the load for shards that only contain disk-offloaded weights
|
||||
@@ -4835,7 +4843,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
continue
|
||||
|
||||
map_location = "cpu"
|
||||
if shard_file.endswith(".safetensors"):
|
||||
if shard_file.endswith(".safetensors") and not is_hqq_or_bnb:
|
||||
map_location = "meta"
|
||||
elif (
|
||||
device_map is not None
|
||||
|
||||
Reference in New Issue
Block a user