From 6cef03ba660a0fe35d8cbcf00195410ae4c7557b Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Fri, 11 Apr 2025 04:43:36 -0700 Subject: [PATCH] [Regression] Fix Quark quantized model loading after refactorization (#37407) --- src/transformers/modeling_utils.py | 26 ++++++++++++------- .../quark_integration/test_quark.py | 1 + 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 226ca6ee61..a73aecfb88 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -649,7 +649,10 @@ def _infer_parameter_dtype( try: old_param = model.get_parameter_or_buffer(param_name) except Exception as e: - if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ: + if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in { + QuantizationMethod.HQQ, + QuantizationMethod.QUARK, + }: return True, None else: raise e @@ -708,11 +711,12 @@ def _load_state_dict_into_meta_model( device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)]) is_quantized = hf_quantizer is not None - is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [ + is_hqq_or_bnb_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in { QuantizationMethod.HQQ, QuantizationMethod.BITS_AND_BYTES, - ] - is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb + QuantizationMethod.QUARK, + } + is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb_or_quark file_pointer = None if is_meta_state_dict: file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) @@ -4632,11 +4636,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi ): # Useful flags is_quantized = hf_quantizer is not None - is_hqq = is_quantized and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ - is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [ + is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in { + QuantizationMethod.HQQ, + QuantizationMethod.QUARK, + } + is_hqq_or_bnb_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in { QuantizationMethod.HQQ, QuantizationMethod.BITS_AND_BYTES, - ] + QuantizationMethod.QUARK, + } # Get all the keys of the state dicts that we have to initialize the model if sharded_metadata is not None: @@ -4798,7 +4806,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys) # Warmup cuda to load the weights much faster on devices - if device_map is not None and not is_hqq: + if device_map is not None and not is_hqq_or_quark: expanded_device_map = expand_device_map(device_map, expected_keys) caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4) @@ -4812,7 +4820,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi map_location = "cpu" if ( shard_file.endswith(".safetensors") - and not is_hqq_or_bnb + and not is_hqq_or_bnb_or_quark and not (is_deepspeed_zero3_enabled() and not is_quantized) ): map_location = "meta" diff --git a/tests/quantization/quark_integration/test_quark.py b/tests/quantization/quark_integration/test_quark.py index 81584fa02e..aefe1ebf44 100644 --- a/tests/quantization/quark_integration/test_quark.py +++ b/tests/quantization/quark_integration/test_quark.py @@ -53,6 +53,7 @@ class QuarkTest(unittest.TestCase): EXPECTED_OUTPUTS.add("Today I am in Paris and I am not in Paris, France\nToday I am in Paris, Illinois") EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying the city of light. I am not just any ordinary Paris") EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying my day off! The sun is shining, the birds are") + EXPECTED_OUTPUTS.add("Today I am in Paris and I'm here to tell you about it. It's a beautiful day,") EXPECTED_RELATIVE_DIFFERENCE = 1.66 device_map = None