From e3d3b546389f2df2812b8490beb25dddc8b93a03 Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Wed, 16 Apr 2025 05:19:56 -0700 Subject: [PATCH] Keep Quark loading through meta device (#37538) --- src/transformers/modeling_utils.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index bd63aad88b..b1ec6896be 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -725,12 +725,11 @@ def _load_state_dict_into_meta_model( device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)]) is_quantized = hf_quantizer is not None - is_hqq_or_bnb_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in { + is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in { QuantizationMethod.HQQ, QuantizationMethod.BITS_AND_BYTES, - QuantizationMethod.QUARK, } - is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb_or_quark + is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb file_pointer = None if is_meta_state_dict: file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) @@ -4701,10 +4700,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi QuantizationMethod.HQQ, QuantizationMethod.QUARK, } - is_hqq_or_bnb_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in { + is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in { QuantizationMethod.HQQ, QuantizationMethod.BITS_AND_BYTES, - QuantizationMethod.QUARK, } # Get all the keys of the state dicts that we have to initialize the model @@ -4881,7 +4879,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi map_location = "cpu" if ( shard_file.endswith(".safetensors") - and not is_hqq_or_bnb_or_quark + and not is_hqq_or_bnb and not (is_deepspeed_zero3_enabled() and not is_quantized) ): map_location = "meta"