From bc2dea3f549d512503640a7f244d45490d918378 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 1 Apr 2025 18:45:58 +0200 Subject: [PATCH] Fix meta state dict loading with quantizers (#37136) Update modeling_utils.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/modeling_utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 0d1126ba54..352c86a13d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -752,7 +752,11 @@ def _load_state_dict_into_meta_model( device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)]) is_quantized = hf_quantizer is not None - is_meta_state_dict = shard_file.endswith(".safetensors") + is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [ + QuantizationMethod.HQQ, + QuantizationMethod.BITS_AND_BYTES, + ] + is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb file_pointer = None if is_meta_state_dict: file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) @@ -4828,6 +4832,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4) error_msgs = [] + is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [ + QuantizationMethod.HQQ, + QuantizationMethod.BITS_AND_BYTES, + ] # Iterate on all the shards to load the weights for shard_file in checkpoint_files: # Skip the load for shards that only contain disk-offloaded weights @@ -4835,7 +4843,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix continue map_location = "cpu" - if shard_file.endswith(".safetensors"): + if shard_file.endswith(".safetensors") and not is_hqq_or_bnb: map_location = "meta" elif ( device_map is not None