@@ -114,6 +114,7 @@ from .utils import (
|
|||||||
is_torch_npu_available,
|
is_torch_npu_available,
|
||||||
is_torch_sdpa_available,
|
is_torch_sdpa_available,
|
||||||
is_torch_xla_available,
|
is_torch_xla_available,
|
||||||
|
is_torch_xpu_available,
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
strtobool,
|
strtobool,
|
||||||
@@ -1286,11 +1287,14 @@ def _get_device_map(
|
|||||||
if hf_quantizer is not None:
|
if hf_quantizer is not None:
|
||||||
max_memory = hf_quantizer.adjust_max_memory(max_memory)
|
max_memory = hf_quantizer.adjust_max_memory(max_memory)
|
||||||
|
|
||||||
# CUDA: `max_memory` contains non-reserved memory. There may be *unused* reserved memory in the GPU, which we
|
# `max_memory` contains non-reserved memory. There may be *unused* reserved memory in the GPU, which we
|
||||||
# can use to allocate parameters.
|
# can use to allocate parameters.
|
||||||
for device_name in max_memory.keys():
|
for device_name in max_memory.keys():
|
||||||
if isinstance(device_name, int): # it's a GPU device
|
if isinstance(device_name, int): # it's a GPU device
|
||||||
unused_memory = torch.cuda.memory_reserved(device_name) - torch.cuda.memory_allocated(device_name)
|
if is_torch_xpu_available():
|
||||||
|
unused_memory = torch.xpu.memory_reserved(device_name) - torch.xpu.memory_allocated(device_name)
|
||||||
|
else:
|
||||||
|
unused_memory = torch.cuda.memory_reserved(device_name) - torch.cuda.memory_allocated(device_name)
|
||||||
max_memory[device_name] += unused_memory
|
max_memory[device_name] += unused_memory
|
||||||
device_map_kwargs["max_memory"] = max_memory
|
device_map_kwargs["max_memory"] = max_memory
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user