More appropriate cuda warmup in resource-constrained hardware (#37550)

* better allocation in resource constrained env

* Update modeling_utils.py

* CIs
This commit is contained in:
Cyril Vallez
2025-04-16 13:40:02 +02:00
committed by GitHub
parent 6fd87d1172
commit 7dafcd0077

View File

@@ -21,7 +21,6 @@ import importlib.metadata
import inspect
import itertools
import json
import math
import os
import re
import shutil
@@ -5872,7 +5871,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
for param_name, device in accelerator_device_map.items():
param = model.get_parameter_or_buffer(param_name)
# The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
param_byte_count = math.prod(param.shape) * param.element_size()
param_byte_count = param.numel() * param.element_size()
if tp_plan_regex is not None:
generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
@@ -5885,8 +5884,14 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
if device.type == "cuda":
index = device.index if device.index is not None else torch.cuda.current_device()
device_memory = torch.cuda.mem_get_info(index)[0]
# Allow up to 95% of max device memory
byte_count = min(byte_count, int(0.95 * device_memory))
# Allow up to (max device memory - 1.2 GiB) in resource-constrained hardware configurations. Trying to reserve more
# than that amount might sometimes lead to unecesary cuda OOM, if the last parameter to be loaded on the device is large,
# and the remaining reserved memory portion is smaller than the param size -> torch will then try to fully re-allocate all
# the param size, instead of using the remaining reserved part, and allocating only the difference, which can lead
# to OOM. See https://github.com/huggingface/transformers/issues/37436#issuecomment-2808982161 for more details.
# Note that we use an absolute value instead of device proportion here, as a 8GiB device could still allocate too much
# if using e.g. 90% of device size, while a 140GiB device would allocate too little
byte_count = min(byte_count, int(device_memory - 1.2 * 1024**3))
# Allocate memory
_ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)