[core] reuse unused reserved cuda memory when loading models (#37920)
This commit is contained in:
@@ -1285,6 +1285,13 @@ def _get_device_map(
|
|||||||
max_memory = get_max_memory(max_memory)
|
max_memory = get_max_memory(max_memory)
|
||||||
if hf_quantizer is not None:
|
if hf_quantizer is not None:
|
||||||
max_memory = hf_quantizer.adjust_max_memory(max_memory)
|
max_memory = hf_quantizer.adjust_max_memory(max_memory)
|
||||||
|
|
||||||
|
# CUDA: `max_memory` contains non-reserved memory. There may be *unused* reserved memory in the GPU, which we
|
||||||
|
# can use to allocate parameters.
|
||||||
|
for device_name in max_memory.keys():
|
||||||
|
if isinstance(device_name, int): # it's a GPU device
|
||||||
|
unused_memory = torch.cuda.memory_reserved(device_name) - torch.cuda.memory_allocated(device_name)
|
||||||
|
max_memory[device_name] += unused_memory
|
||||||
device_map_kwargs["max_memory"] = max_memory
|
device_map_kwargs["max_memory"] = max_memory
|
||||||
|
|
||||||
device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
|
device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
|
||||||
@@ -5979,6 +5986,9 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
|
|||||||
# Note that we use an absolute value instead of device proportion here, as a 8GiB device could still allocate too much
|
# Note that we use an absolute value instead of device proportion here, as a 8GiB device could still allocate too much
|
||||||
# if using e.g. 90% of device size, while a 140GiB device would allocate too little
|
# if using e.g. 90% of device size, while a 140GiB device would allocate too little
|
||||||
byte_count = min(byte_count, max(0, int(device_memory - 1.2 * 1024**3)))
|
byte_count = min(byte_count, max(0, int(device_memory - 1.2 * 1024**3)))
|
||||||
|
# If there is *unused* reserved cuda memory, we can skip/reduce the allocation.
|
||||||
|
unused_memory = torch.cuda.memory_reserved(index) - torch.cuda.memory_allocated(index)
|
||||||
|
byte_count = max(0, byte_count - unused_memory)
|
||||||
# Allocate memory
|
# Allocate memory
|
||||||
_ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
|
_ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user