From 96089086395271e1e6ced793a54a9ff308f71432 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 22 Apr 2025 13:12:49 +0200 Subject: [PATCH] Correct warm-up with fp8 (#37670) * start clean warmup for quantizers * style --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/modeling_utils.py | 6 ++++-- src/transformers/quantizers/base.py | 11 +++++++++++ .../quantizers/quantizer_finegrained_fp8.py | 4 ++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 50a200ae76..0dbc97781e 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4866,7 +4866,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi # Warmup cuda to load the weights much faster on devices if device_map is not None and not is_hqq_or_quark: expanded_device_map = expand_device_map(device_map, expected_keys) - caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4) + caching_allocator_warmup(model_to_load, expanded_device_map, hf_quantizer) error_msgs = [] # Iterate on all the shards to load the weights @@ -5871,7 +5871,7 @@ def is_accelerator_device(device: Union[str, int, torch.device]) -> bool: return torch.device(device).type not in ["meta", "cpu"] -def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, factor=2): +def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, hf_quantizer: Optional[HfQuantizer]): """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model, which is actually the loading speed botteneck. @@ -5890,6 +5890,8 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, - Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices. However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end. """ + factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor() + # Remove disk, cpu and meta devices, and cast to proper torch.device accelerator_device_map = { param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device) diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py index e075af9618..d5ae46a0af 100644 --- a/src/transformers/quantizers/base.py +++ b/src/transformers/quantizers/base.py @@ -252,6 +252,17 @@ class HfQuantizer(ABC): return model + def get_cuda_warm_up_factor(self): + """ + The factor to be used in `caching_allocator_warmup` to get the number of bytes to pre-allocate to warm up cuda. + A factor of 2 means we allocate all bytes in the empty model (since we allocate in fp16), a factor of 4 means + we allocate half the memory of the weights residing in the empty model, etc... + """ + # By default we return 4, i.e. half the model size (this corresponds to the case where the model is not + # really pre-processed, i.e. we do not have the info that weights are going to be 8 bits before actual + # weight loading) + return 4 + def _dequantize(self, model): raise NotImplementedError( f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub." diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py index 1895eab648..76f6f9221c 100644 --- a/src/transformers/quantizers/quantizer_finegrained_fp8.py +++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py @@ -200,3 +200,7 @@ class FineGrainedFP8HfQuantizer(HfQuantizer): @property def is_trainable(self) -> bool: return False + + def get_cuda_warm_up_factor(self): + # Pre-processing is done cleanly, so we can allocate everything here + return 2