Correct warm-up with fp8 (#37670)

* start clean warmup for quantizers

* style

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Cyril Vallez
2025-04-22 13:12:49 +02:00
committed by GitHub
parent 6614209b96
commit 9608908639
3 changed files with 19 additions and 2 deletions

View File

@@ -4866,7 +4866,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# Warmup cuda to load the weights much faster on devices
if device_map is not None and not is_hqq_or_quark:
expanded_device_map = expand_device_map(device_map, expected_keys)
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
caching_allocator_warmup(model_to_load, expanded_device_map, hf_quantizer)
error_msgs = []
# Iterate on all the shards to load the weights
@@ -5871,7 +5871,7 @@ def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
return torch.device(device).type not in ["meta", "cpu"]
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, factor=2):
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, hf_quantizer: Optional[HfQuantizer]):
"""This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
the model, which is actually the loading speed botteneck.
@@ -5890,6 +5890,8 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
- Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
"""
factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
# Remove disk, cpu and meta devices, and cast to proper torch.device
accelerator_device_map = {
param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)