From ff5ef95db7f38270eda4b893fb285385b246799a Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Tue, 6 May 2025 17:57:49 +0800 Subject: [PATCH] add xpu memory check (#37969) add xpu check --- src/transformers/modeling_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 9457814b23..f104d7c45f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -114,6 +114,7 @@ from .utils import ( is_torch_npu_available, is_torch_sdpa_available, is_torch_xla_available, + is_torch_xpu_available, logging, replace_return_docstrings, strtobool, @@ -1286,11 +1287,14 @@ def _get_device_map( if hf_quantizer is not None: max_memory = hf_quantizer.adjust_max_memory(max_memory) - # CUDA: `max_memory` contains non-reserved memory. There may be *unused* reserved memory in the GPU, which we + # `max_memory` contains non-reserved memory. There may be *unused* reserved memory in the GPU, which we # can use to allocate parameters. for device_name in max_memory.keys(): if isinstance(device_name, int): # it's a GPU device - unused_memory = torch.cuda.memory_reserved(device_name) - torch.cuda.memory_allocated(device_name) + if is_torch_xpu_available(): + unused_memory = torch.xpu.memory_reserved(device_name) - torch.xpu.memory_allocated(device_name) + else: + unused_memory = torch.cuda.memory_reserved(device_name) - torch.cuda.memory_allocated(device_name) max_memory[device_name] += unused_memory device_map_kwargs["max_memory"] = max_memory