Remove HQQ from caching allocator warmup (#37347)
Update modeling_utils.py
This commit is contained in:
@@ -4612,6 +4612,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
):
|
||||
# Useful flags
|
||||
is_quantized = hf_quantizer is not None
|
||||
is_hqq = is_quantized and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
|
||||
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
|
||||
QuantizationMethod.HQQ,
|
||||
QuantizationMethod.BITS_AND_BYTES,
|
||||
@@ -4777,7 +4778,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
|
||||
|
||||
# Warmup cuda to load the weights much faster on devices
|
||||
if device_map is not None:
|
||||
if device_map is not None and not is_hqq:
|
||||
expanded_device_map = expand_device_map(device_map, expected_keys)
|
||||
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user