Keep Quark loading through meta device (#37538)

This commit is contained in:
Bowen Bao
2025-04-16 05:19:56 -07:00
committed by GitHub
parent 61436a9323
commit e3d3b54638

View File

@@ -725,12 +725,11 @@ def _load_state_dict_into_meta_model(
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
is_quantized = hf_quantizer is not None
is_hqq_or_bnb_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES,
QuantizationMethod.QUARK,
}
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb_or_quark
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb
file_pointer = None
if is_meta_state_dict:
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
@@ -4701,10 +4700,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
QuantizationMethod.HQQ,
QuantizationMethod.QUARK,
}
is_hqq_or_bnb_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES,
QuantizationMethod.QUARK,
}
# Get all the keys of the state dicts that we have to initialize the model
@@ -4881,7 +4879,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
map_location = "cpu"
if (
shard_file.endswith(".safetensors")
and not is_hqq_or_bnb_or_quark
and not is_hqq_or_bnb
and not (is_deepspeed_zero3_enabled() and not is_quantized)
):
map_location = "meta"