More appropriate cuda warmup in resource-constrained hardware (#37550)
* better allocation in resource constrained env * Update modeling_utils.py * CIs
This commit is contained in:
@@ -21,7 +21,6 @@ import importlib.metadata
|
||||
import inspect
|
||||
import itertools
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@@ -5872,7 +5871,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
|
||||
for param_name, device in accelerator_device_map.items():
|
||||
param = model.get_parameter_or_buffer(param_name)
|
||||
# The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
|
||||
param_byte_count = math.prod(param.shape) * param.element_size()
|
||||
param_byte_count = param.numel() * param.element_size()
|
||||
|
||||
if tp_plan_regex is not None:
|
||||
generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
|
||||
@@ -5885,8 +5884,14 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
|
||||
if device.type == "cuda":
|
||||
index = device.index if device.index is not None else torch.cuda.current_device()
|
||||
device_memory = torch.cuda.mem_get_info(index)[0]
|
||||
# Allow up to 95% of max device memory
|
||||
byte_count = min(byte_count, int(0.95 * device_memory))
|
||||
# Allow up to (free device memory - 1.2 GiB) in resource-constrained hardware configurations. Trying to reserve more
|
||||
# than that amount might sometimes lead to unnecessary CUDA OOM, if the last parameter to be loaded on the device is large,
|
||||
# and the remaining reserved memory portion is smaller than the param size -> torch will then try to fully re-allocate all
|
||||
# the param size, instead of using the remaining reserved part, and allocating only the difference, which can lead
|
||||
# to OOM. See https://github.com/huggingface/transformers/issues/37436#issuecomment-2808982161 for more details.
|
||||
# Note that we use an absolute value instead of device proportion here, as an 8GiB device could still allocate too much
|
||||
# if using e.g. 90% of device size, while a 140GiB device would allocate too little
|
||||
byte_count = min(byte_count, int(device_memory - 1.2 * 1024**3))
|
||||
# Allocate memory
|
||||
_ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user