Remove HQQ from caching allocator warmup (#37347)

Update modeling_utils.py
This commit is contained in:
Cyril Vallez
2025-04-07 18:33:48 +02:00
committed by GitHub
parent 832cb684a0
commit 48e179857c

View File

@@ -4612,6 +4612,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
):
# Useful flags
is_quantized = hf_quantizer is not None
is_hqq = is_quantized and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES,
@@ -4777,7 +4778,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
# Warmup cuda to load the weights much faster on devices
if device_map is not None:
if device_map is not None and not is_hqq:
expanded_device_map = expand_device_map(device_map, expected_keys)
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)