Correct warm-up with fp8 (#37670)
* start clean warmup for quantizers * style --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -4866,7 +4866,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
# Warmup cuda to load the weights much faster on devices
|
||||
if device_map is not None and not is_hqq_or_quark:
|
||||
expanded_device_map = expand_device_map(device_map, expected_keys)
|
||||
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
|
||||
caching_allocator_warmup(model_to_load, expanded_device_map, hf_quantizer)
|
||||
|
||||
error_msgs = []
|
||||
# Iterate on all the shards to load the weights
|
||||
@@ -5871,7 +5871,7 @@ def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
|
||||
return torch.device(device).type not in ["meta", "cpu"]
|
||||
|
||||
|
||||
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, factor=2):
|
||||
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, hf_quantizer: Optional[HfQuantizer]):
|
||||
"""This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
|
||||
device. It makes it possible to issue one large call to Malloc, instead of recursively calling it later when loading
|
||||
the model, which is actually the loading speed bottleneck.
|
||||
@@ -5890,6 +5890,8 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
|
||||
- Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
|
||||
However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
|
||||
"""
|
||||
factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
|
||||
|
||||
# Remove disk, cpu and meta devices, and cast to proper torch.device
|
||||
accelerator_device_map = {
|
||||
param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)
|
||||
|
||||
Reference in New Issue
Block a user