[Regression] Fix Quark quantized model loading after refactorization (#37407)
This commit is contained in:
@@ -649,7 +649,10 @@ def _infer_parameter_dtype(
|
|||||||
try:
|
try:
|
||||||
old_param = model.get_parameter_or_buffer(param_name)
|
old_param = model.get_parameter_or_buffer(param_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ:
|
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in {
|
||||||
|
QuantizationMethod.HQQ,
|
||||||
|
QuantizationMethod.QUARK,
|
||||||
|
}:
|
||||||
return True, None
|
return True, None
|
||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
@@ -708,11 +711,12 @@ def _load_state_dict_into_meta_model(
|
|||||||
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
|
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
|
||||||
|
|
||||||
is_quantized = hf_quantizer is not None
|
is_quantized = hf_quantizer is not None
|
||||||
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
|
is_hqq_or_bnb_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
|
||||||
QuantizationMethod.HQQ,
|
QuantizationMethod.HQQ,
|
||||||
QuantizationMethod.BITS_AND_BYTES,
|
QuantizationMethod.BITS_AND_BYTES,
|
||||||
]
|
QuantizationMethod.QUARK,
|
||||||
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb
|
}
|
||||||
|
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb_or_quark
|
||||||
file_pointer = None
|
file_pointer = None
|
||||||
if is_meta_state_dict:
|
if is_meta_state_dict:
|
||||||
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
|
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
|
||||||
@@ -4632,11 +4636,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
):
|
):
|
||||||
# Useful flags
|
# Useful flags
|
||||||
is_quantized = hf_quantizer is not None
|
is_quantized = hf_quantizer is not None
|
||||||
is_hqq = is_quantized and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
|
is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
|
||||||
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
|
QuantizationMethod.HQQ,
|
||||||
|
QuantizationMethod.QUARK,
|
||||||
|
}
|
||||||
|
is_hqq_or_bnb_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
|
||||||
QuantizationMethod.HQQ,
|
QuantizationMethod.HQQ,
|
||||||
QuantizationMethod.BITS_AND_BYTES,
|
QuantizationMethod.BITS_AND_BYTES,
|
||||||
]
|
QuantizationMethod.QUARK,
|
||||||
|
}
|
||||||
|
|
||||||
# Get all the keys of the state dicts that we have to initialize the model
|
# Get all the keys of the state dicts that we have to initialize the model
|
||||||
if sharded_metadata is not None:
|
if sharded_metadata is not None:
|
||||||
@@ -4798,7 +4806,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
|
expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
|
||||||
|
|
||||||
# Warmup cuda to load the weights much faster on devices
|
# Warmup cuda to load the weights much faster on devices
|
||||||
if device_map is not None and not is_hqq:
|
if device_map is not None and not is_hqq_or_quark:
|
||||||
expanded_device_map = expand_device_map(device_map, expected_keys)
|
expanded_device_map = expand_device_map(device_map, expected_keys)
|
||||||
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
|
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
|
||||||
|
|
||||||
@@ -4812,7 +4820,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
map_location = "cpu"
|
map_location = "cpu"
|
||||||
if (
|
if (
|
||||||
shard_file.endswith(".safetensors")
|
shard_file.endswith(".safetensors")
|
||||||
and not is_hqq_or_bnb
|
and not is_hqq_or_bnb_or_quark
|
||||||
and not (is_deepspeed_zero3_enabled() and not is_quantized)
|
and not (is_deepspeed_zero3_enabled() and not is_quantized)
|
||||||
):
|
):
|
||||||
map_location = "meta"
|
map_location = "meta"
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ class QuarkTest(unittest.TestCase):
|
|||||||
EXPECTED_OUTPUTS.add("Today I am in Paris and I am not in Paris, France\nToday I am in Paris, Illinois")
|
EXPECTED_OUTPUTS.add("Today I am in Paris and I am not in Paris, France\nToday I am in Paris, Illinois")
|
||||||
EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying the city of light. I am not just any ordinary Paris")
|
EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying the city of light. I am not just any ordinary Paris")
|
||||||
EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying my day off! The sun is shining, the birds are")
|
EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying my day off! The sun is shining, the birds are")
|
||||||
|
EXPECTED_OUTPUTS.add("Today I am in Paris and I'm here to tell you about it. It's a beautiful day,")
|
||||||
|
|
||||||
EXPECTED_RELATIVE_DIFFERENCE = 1.66
|
EXPECTED_RELATIVE_DIFFERENCE = 1.66
|
||||||
device_map = None
|
device_map = None
|
||||||
|
|||||||
Reference in New Issue
Block a user