[offload] respect max_memory argument when factoring in unused reserved memory (#37982)
This commit is contained in:
@@ -1275,7 +1275,7 @@ def _get_device_map(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if device_map != "sequential":
|
if device_map != "sequential":
|
||||||
max_memory = get_balanced_memory(
|
inferred_max_memory = get_balanced_memory(
|
||||||
model,
|
model,
|
||||||
dtype=target_dtype,
|
dtype=target_dtype,
|
||||||
low_zero=(device_map == "balanced_low_0"),
|
low_zero=(device_map == "balanced_low_0"),
|
||||||
@@ -1283,20 +1283,23 @@ def _get_device_map(
|
|||||||
**device_map_kwargs,
|
**device_map_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
max_memory = get_max_memory(max_memory)
|
inferred_max_memory = get_max_memory(max_memory)
|
||||||
if hf_quantizer is not None:
|
if hf_quantizer is not None:
|
||||||
max_memory = hf_quantizer.adjust_max_memory(max_memory)
|
inferred_max_memory = hf_quantizer.adjust_max_memory(inferred_max_memory)
|
||||||
|
|
||||||
# `max_memory` contains non-reserved memory. There may be *unused* reserved memory in the GPU, which we
|
# `inferred_max_memory` contains non-reserved memory. There may be *unused* reserved memory in the GPU,
|
||||||
# can use to allocate parameters.
|
# which we can use to allocate parameters.
|
||||||
for device_name in max_memory.keys():
|
for device_name in inferred_max_memory.keys():
|
||||||
if isinstance(device_name, int): # it's a GPU device
|
if isinstance(device_name, int): # it's a GPU device
|
||||||
if is_torch_xpu_available():
|
if is_torch_xpu_available():
|
||||||
unused_memory = torch.xpu.memory_reserved(device_name) - torch.xpu.memory_allocated(device_name)
|
unused_memory = torch.xpu.memory_reserved(device_name) - torch.xpu.memory_allocated(device_name)
|
||||||
else:
|
else:
|
||||||
unused_memory = torch.cuda.memory_reserved(device_name) - torch.cuda.memory_allocated(device_name)
|
unused_memory = torch.cuda.memory_reserved(device_name) - torch.cuda.memory_allocated(device_name)
|
||||||
max_memory[device_name] += unused_memory
|
inferred_max_memory[device_name] += unused_memory
|
||||||
device_map_kwargs["max_memory"] = max_memory
|
# respect the `max_memory` passed by the user
|
||||||
|
if max_memory is not None and device_name in max_memory:
|
||||||
|
inferred_max_memory[device_name] = min(inferred_max_memory[device_name], max_memory[device_name])
|
||||||
|
device_map_kwargs["max_memory"] = inferred_max_memory
|
||||||
|
|
||||||
device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
|
device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user