Remove HQQ from caching allocator warmup (#37347)
Update modeling_utils.py
This commit is contained in:
@@ -4612,6 +4612,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
):
|
):
|
||||||
# Useful flags
|
# Useful flags
|
||||||
is_quantized = hf_quantizer is not None
|
is_quantized = hf_quantizer is not None
|
||||||
|
is_hqq = is_quantized and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
|
||||||
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
|
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
|
||||||
QuantizationMethod.HQQ,
|
QuantizationMethod.HQQ,
|
||||||
QuantizationMethod.BITS_AND_BYTES,
|
QuantizationMethod.BITS_AND_BYTES,
|
||||||
@@ -4777,7 +4778,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
|
expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
|
||||||
|
|
||||||
# Warmup cuda to load the weights much faster on devices
|
# Warmup cuda to load the weights much faster on devices
|
||||||
if device_map is not None:
|
if device_map is not None and not is_hqq:
|
||||||
expanded_device_map = expand_device_map(device_map, expected_keys)
|
expanded_device_map = expand_device_map(device_map, expected_keys)
|
||||||
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
|
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user