Fix meta state dict loading with quantizers (#37136)
Update modeling_utils.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -752,7 +752,11 @@ def _load_state_dict_into_meta_model(
|
|||||||
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
|
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
|
||||||
|
|
||||||
is_quantized = hf_quantizer is not None
|
is_quantized = hf_quantizer is not None
|
||||||
is_meta_state_dict = shard_file.endswith(".safetensors")
|
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
|
||||||
|
QuantizationMethod.HQQ,
|
||||||
|
QuantizationMethod.BITS_AND_BYTES,
|
||||||
|
]
|
||||||
|
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb
|
||||||
file_pointer = None
|
file_pointer = None
|
||||||
if is_meta_state_dict:
|
if is_meta_state_dict:
|
||||||
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
|
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
|
||||||
@@ -4828,6 +4832,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
|
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
|
||||||
|
|
||||||
error_msgs = []
|
error_msgs = []
|
||||||
|
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
|
||||||
|
QuantizationMethod.HQQ,
|
||||||
|
QuantizationMethod.BITS_AND_BYTES,
|
||||||
|
]
|
||||||
# Iterate on all the shards to load the weights
|
# Iterate on all the shards to load the weights
|
||||||
for shard_file in checkpoint_files:
|
for shard_file in checkpoint_files:
|
||||||
# Skip the load for shards that only contain disk-offloaded weights
|
# Skip the load for shards that only contain disk-offloaded weights
|
||||||
@@ -4835,7 +4843,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
map_location = "cpu"
|
map_location = "cpu"
|
||||||
if shard_file.endswith(".safetensors"):
|
if shard_file.endswith(".safetensors") and not is_hqq_or_bnb:
|
||||||
map_location = "meta"
|
map_location = "meta"
|
||||||
elif (
|
elif (
|
||||||
device_map is not None
|
device_map is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user