From a9384f849a2bfc812ae145b8b4d1f5d4afdf0071 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 7 May 2025 09:49:31 +0100 Subject: [PATCH] [offload] respect `max_memory` argument when factoring in unused reserved memory (#37982) --- src/transformers/modeling_utils.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f104d7c45f..be84fdee54 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1275,7 +1275,7 @@ def _get_device_map( ) if device_map != "sequential": - max_memory = get_balanced_memory( + inferred_max_memory = get_balanced_memory( model, dtype=target_dtype, low_zero=(device_map == "balanced_low_0"), @@ -1283,20 +1283,23 @@ def _get_device_map( **device_map_kwargs, ) else: - max_memory = get_max_memory(max_memory) + inferred_max_memory = get_max_memory(max_memory) if hf_quantizer is not None: - max_memory = hf_quantizer.adjust_max_memory(max_memory) + inferred_max_memory = hf_quantizer.adjust_max_memory(inferred_max_memory) - # `max_memory` contains non-reserved memory. There may be *unused* reserved memory in the GPU, which we - # can use to allocate parameters. - for device_name in max_memory.keys(): + # `inferred_max_memory` contains non-reserved memory. There may be *unused* reserved memory in the GPU, + # which we can use to allocate parameters. + for device_name in inferred_max_memory.keys(): if isinstance(device_name, int): # it's a GPU device if is_torch_xpu_available(): unused_memory = torch.xpu.memory_reserved(device_name) - torch.xpu.memory_allocated(device_name) else: unused_memory = torch.cuda.memory_reserved(device_name) - torch.cuda.memory_allocated(device_name) - max_memory[device_name] += unused_memory - device_map_kwargs["max_memory"] = max_memory + inferred_max_memory[device_name] += unused_memory + # respect the `max_memory` passed by the user + if max_memory is not None and device_name in max_memory: + inferred_max_memory[device_name] = min(inferred_max_memory[device_name], max_memory[device_name]) + device_map_kwargs["max_memory"] = inferred_max_memory device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)