@@ -5870,7 +5870,8 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict):
|
|||||||
# This will kick off the caching allocator to avoid having to Malloc afterwards
|
# This will kick off the caching allocator to avoid having to Malloc afterwards
|
||||||
for device, byte_count in total_byte_count.items():
|
for device, byte_count in total_byte_count.items():
|
||||||
if device.type == "cuda":
|
if device.type == "cuda":
|
||||||
device_memory = torch.cuda.mem_get_info(device)[0]
|
index = device.index if device.index is not None else torch.cuda.current_device()
|
||||||
|
device_memory = torch.cuda.mem_get_info(index)[0]
|
||||||
# Allow up to 95% of max device memory
|
# Allow up to 95% of max device memory
|
||||||
byte_count = min(byte_count, int(0.95 * device_memory))
|
byte_count = min(byte_count, int(0.95 * device_memory))
|
||||||
# Allocate memory
|
# Allocate memory
|
||||||
|
|||||||
Reference in New Issue
Block a user