Keep Quark loading through meta device (#37538)
This commit is contained in:
@@ -725,12 +725,11 @@ def _load_state_dict_into_meta_model(
|
|||||||
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
|
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
|
||||||
|
|
||||||
is_quantized = hf_quantizer is not None
|
is_quantized = hf_quantizer is not None
|
||||||
is_hqq_or_bnb_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
|
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in {
|
||||||
QuantizationMethod.HQQ,
|
QuantizationMethod.HQQ,
|
||||||
QuantizationMethod.BITS_AND_BYTES,
|
QuantizationMethod.BITS_AND_BYTES,
|
||||||
QuantizationMethod.QUARK,
|
|
||||||
}
|
}
|
||||||
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb_or_quark
|
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb
|
||||||
file_pointer = None
|
file_pointer = None
|
||||||
if is_meta_state_dict:
|
if is_meta_state_dict:
|
||||||
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
|
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
|
||||||
@@ -4701,10 +4700,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
QuantizationMethod.HQQ,
|
QuantizationMethod.HQQ,
|
||||||
QuantizationMethod.QUARK,
|
QuantizationMethod.QUARK,
|
||||||
}
|
}
|
||||||
is_hqq_or_bnb_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
|
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in {
|
||||||
QuantizationMethod.HQQ,
|
QuantizationMethod.HQQ,
|
||||||
QuantizationMethod.BITS_AND_BYTES,
|
QuantizationMethod.BITS_AND_BYTES,
|
||||||
QuantizationMethod.QUARK,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Get all the keys of the state dicts that we have to initialize the model
|
# Get all the keys of the state dicts that we have to initialize the model
|
||||||
@@ -4881,7 +4879,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
map_location = "cpu"
|
map_location = "cpu"
|
||||||
if (
|
if (
|
||||||
shard_file.endswith(".safetensors")
|
shard_file.endswith(".safetensors")
|
||||||
and not is_hqq_or_bnb_or_quark
|
and not is_hqq_or_bnb
|
||||||
and not (is_deepspeed_zero3_enabled() and not is_quantized)
|
and not (is_deepspeed_zero3_enabled() and not is_quantized)
|
||||||
):
|
):
|
||||||
map_location = "meta"
|
map_location = "meta"
|
||||||
|
|||||||
Reference in New Issue
Block a user