Loading optimizations (#36742)

* improvements

* Update modeling_utils.py

* add some doc about loading

* Update modeling_utils.py
This commit is contained in:
Cyril Vallez
2025-03-18 16:38:44 +01:00
committed by GitHub
parent 7baf00089a
commit db1d4c5a0b

View File

@@ -4824,11 +4824,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Warmup cuda to load the weights much faster on devices # Warmup cuda to load the weights much faster on devices
if device_map is not None and hf_quantizer is None: if device_map is not None and hf_quantizer is None:
expanded_device_map = expand_device_map(device_map, expected_keys) expanded_device_map = expand_device_map(device_map, expected_keys)
caching_allocator_warmup(model_to_load, expanded_device_map, dtype) caching_allocator_warmup(model_to_load, expanded_device_map)
error_msgs = [] error_msgs = []
mismatched_keys = [] mismatched_keys = []
has_multiple_shards = len(checkpoint_files) > 1
# Iterate on all the shards to load the weights # Iterate on all the shards to load the weights
for shard_file in checkpoint_files: for shard_file in checkpoint_files:
# Skip the load for shards that only contain disk-offloaded weights # Skip the load for shards that only contain disk-offloaded weights
@@ -4865,7 +4864,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
prefix if loading_base_model_from_task_state_dict else "", prefix if loading_base_model_from_task_state_dict else "",
) )
if low_cpu_mem_usage and shard_file is not None: if low_cpu_mem_usage:
# Skip it with fsdp on ranks other than 0 # Skip it with fsdp on ranks other than 0
if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized): if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model( disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
@@ -4893,10 +4892,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else: else:
model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params) model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params)
# force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
del state_dict del state_dict
# force memory release if loading multiple shards
if has_multiple_shards:
gc.collect()
# Adjust offloaded weights name and save if needed # Adjust offloaded weights name and save if needed
if disk_offload_index is not None and len(disk_offload_index) > 0: if disk_offload_index is not None and len(disk_offload_index) > 0:
@@ -5789,11 +5786,24 @@ def expand_device_map(device_map, param_names):
return new_device_map return new_device_map
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, dtype: torch.dtype) -> Dict: def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict):
"""This function warm-ups the caching allocator based on the size of the model tensors that will reside on each """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
the model, which is actually the loading speed bottleneck. the model, which is actually the loading speed bottleneck.
Calling this function allows to cut the model loading time by a very large margin. Calling this function allows to cut the model loading time by a very large margin.
A few facts related to loading speed (taking into account the use of this function):
- When loading a model the first time, it is usually slower than the subsequent times, because the OS is very likely
to cache the different state dicts (if enough resources/RAM are available)
- Trying to force the OS to cache the files in advance (by e.g. accessing a small portion of them) is really hard,
and not a good idea in general, as these are low-level OS optimizations that depend on resource usage anyway
- As of 18/03/2025, loading a Llama 70B model with TP takes ~1 min without file cache, and ~13s with full file cache.
The baseline, i.e. only loading the tensor shards on device and adjusting dtype (i.e. copying them) is ~5s with full cache.
These numbers are reported for TP on 4 H100 GPUs.
- It is useless to pre-allocate more than the model size in this function (i.e. using an `allocation_factor` > 1) as
cudaMalloc is not a bottleneck at all anymore
- Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
""" """
# Remove disk and cpu devices, and cast to proper torch.device # Remove disk and cpu devices, and cast to proper torch.device
accelerator_device_map = { accelerator_device_map = {
@@ -5808,31 +5818,26 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
else None else None
) )
parameter_count = defaultdict(lambda: 0) total_byte_count = defaultdict(lambda: 0)
allocation_factor = 1
if torch.distributed.is_initialized() or len(set(accelerator_device_map.values())) >= 2:
allocation_factor = 2
for param_name, device in accelerator_device_map.items(): for param_name, device in accelerator_device_map.items():
param = model.get_parameter_or_buffer(param_name) param = model.get_parameter_or_buffer(param_name)
param_size = int(math.prod(param.shape) * allocation_factor) # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
param_byte_count = math.prod(param.shape) * dtype_byte_size(param.dtype)
if tp_plan_regex is not None: if tp_plan_regex is not None:
generic_name = re.sub(r"\.\d+\.", ".*.", param_name) generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
param_size //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1 param_byte_count //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1
parameter_count[device] += param_size total_byte_count[device] += param_byte_count
dtype = dtype if dtype is not None else torch.float32
# This will kick off the caching allocator to avoid having to Malloc afterwards # This will kick off the caching allocator to avoid having to Malloc afterwards
for device, param_count in parameter_count.items(): for device, byte_count in total_byte_count.items():
max_memory_device = None
if device.type == "cuda": if device.type == "cuda":
max_memory_device = torch.cuda.mem_get_info(device.index)[0] device_memory = torch.cuda.mem_get_info(device)[0]
# allocate only if we have enough memory # Allow up to 95% of max device memory
if max_memory_device is None or max_memory_device > param_count * dtype_byte_size(dtype): byte_count = min(byte_count, int(0.95 * device_memory))
_ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False) # Allocate memory
_ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)
def get_disk_only_shard_files(device_map, weight_map): def get_disk_only_shard_files(device_map, weight_map):