From 80b4c5dcc9ec3dae38a80ff426dd60f5658ab4bc Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Tue, 25 Mar 2025 11:51:41 +0100 Subject: [PATCH] Fix cuda index issue in cache allocator (#36937) fix --- src/transformers/modeling_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ec3b37404d..8438865a0b 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -5870,7 +5870,8 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict): # This will kick off the caching allocator to avoid having to Malloc afterwards for device, byte_count in total_byte_count.items(): if device.type == "cuda": - device_memory = torch.cuda.mem_get_info(device)[0] + index = device.index if device.index is not None else torch.cuda.current_device() + device_memory = torch.cuda.mem_get_info(index)[0] # Allow up to 95% of max device memory byte_count = min(byte_count, int(0.95 * device_memory)) # Allocate memory