Fix meta state dict loading with quantizers (#37136)

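Update modeling_utils.py: when the quantizer is HQQ or bitsandbytes, do not treat `.safetensors` shards as meta state dicts and do not load them with `map_location="meta"`.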

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Cyril Vallez
2025-04-01 18:45:58 +02:00
committed by GitHub
parent 35253076f4
commit bc2dea3f54

View File

@@ -752,7 +752,11 @@ def _load_state_dict_into_meta_model(
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)]) device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
is_quantized = hf_quantizer is not None is_quantized = hf_quantizer is not None
is_meta_state_dict = shard_file.endswith(".safetensors") is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES,
]
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb
file_pointer = None file_pointer = None
if is_meta_state_dict: if is_meta_state_dict:
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
@@ -4828,6 +4832,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4) caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
error_msgs = [] error_msgs = []
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES,
]
# Iterate on all the shards to load the weights # Iterate on all the shards to load the weights
for shard_file in checkpoint_files: for shard_file in checkpoint_files:
# Skip the load for shards that only contain disk-offloaded weights # Skip the load for shards that only contain disk-offloaded weights
@@ -4835,7 +4843,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
continue continue
map_location = "cpu" map_location = "cpu"
if shard_file.endswith(".safetensors"): if shard_file.endswith(".safetensors") and not is_hqq_or_bnb:
map_location = "meta" map_location = "meta"
elif ( elif (
device_map is not None device_map is not None