[Regression] Fix Quark quantized model loading after refactorization (#37407)

This commit is contained in:
Bowen Bao
2025-04-11 04:43:36 -07:00
committed by GitHub
parent a563999a02
commit 6cef03ba66
2 changed files with 18 additions and 9 deletions

View File

@@ -649,7 +649,10 @@ def _infer_parameter_dtype(
try: try:
old_param = model.get_parameter_or_buffer(param_name) old_param = model.get_parameter_or_buffer(param_name)
except Exception as e: except Exception as e:
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ: if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.QUARK,
}:
return True, None return True, None
else: else:
raise e raise e
@@ -708,11 +711,12 @@ def _load_state_dict_into_meta_model(
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)]) device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
is_quantized = hf_quantizer is not None is_quantized = hf_quantizer is not None
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [ is_hqq_or_bnb_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ, QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES, QuantizationMethod.BITS_AND_BYTES,
] QuantizationMethod.QUARK,
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb }
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb_or_quark
file_pointer = None file_pointer = None
if is_meta_state_dict: if is_meta_state_dict:
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
@@ -4632,11 +4636,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
): ):
# Useful flags # Useful flags
is_quantized = hf_quantizer is not None is_quantized = hf_quantizer is not None
is_hqq = is_quantized and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [ QuantizationMethod.HQQ,
QuantizationMethod.QUARK,
}
is_hqq_or_bnb_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ, QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES, QuantizationMethod.BITS_AND_BYTES,
] QuantizationMethod.QUARK,
}
# Get all the keys of the state dicts that we have to initialize the model # Get all the keys of the state dicts that we have to initialize the model
if sharded_metadata is not None: if sharded_metadata is not None:
@@ -4798,7 +4806,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys) expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
# Warmup cuda to load the weights much faster on devices # Warmup cuda to load the weights much faster on devices
if device_map is not None and not is_hqq: if device_map is not None and not is_hqq_or_quark:
expanded_device_map = expand_device_map(device_map, expected_keys) expanded_device_map = expand_device_map(device_map, expected_keys)
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4) caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
@@ -4812,7 +4820,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
map_location = "cpu" map_location = "cpu"
if ( if (
shard_file.endswith(".safetensors") shard_file.endswith(".safetensors")
and not is_hqq_or_bnb and not is_hqq_or_bnb_or_quark
and not (is_deepspeed_zero3_enabled() and not is_quantized) and not (is_deepspeed_zero3_enabled() and not is_quantized)
): ):
map_location = "meta" map_location = "meta"

View File

@@ -53,6 +53,7 @@ class QuarkTest(unittest.TestCase):
EXPECTED_OUTPUTS.add("Today I am in Paris and I am not in Paris, France\nToday I am in Paris, Illinois") EXPECTED_OUTPUTS.add("Today I am in Paris and I am not in Paris, France\nToday I am in Paris, Illinois")
EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying the city of light. I am not just any ordinary Paris") EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying the city of light. I am not just any ordinary Paris")
EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying my day off! The sun is shining, the birds are") EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying my day off! The sun is shining, the birds are")
EXPECTED_OUTPUTS.add("Today I am in Paris and I'm here to tell you about it. It's a beautiful day,")
EXPECTED_RELATIVE_DIFFERENCE = 1.66 EXPECTED_RELATIVE_DIFFERENCE = 1.66
device_map = None device_map = None