Loading optimizations (#36742)
* improvements * Update modeling_utils.py * add some doc about loading * Update modeling_utils.py
This commit is contained in:
@@ -4824,11 +4824,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# Warmup cuda to load the weights much faster on devices
|
||||
if device_map is not None and hf_quantizer is None:
|
||||
expanded_device_map = expand_device_map(device_map, expected_keys)
|
||||
caching_allocator_warmup(model_to_load, expanded_device_map, dtype)
|
||||
caching_allocator_warmup(model_to_load, expanded_device_map)
|
||||
|
||||
error_msgs = []
|
||||
mismatched_keys = []
|
||||
has_multiple_shards = len(checkpoint_files) > 1
|
||||
# Iterate on all the shards to load the weights
|
||||
for shard_file in checkpoint_files:
|
||||
# Skip the load for shards that only contain disk-offloaded weights
|
||||
@@ -4865,7 +4864,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
prefix if loading_base_model_from_task_state_dict else "",
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage and shard_file is not None:
|
||||
if low_cpu_mem_usage:
|
||||
# Skip it with fsdp on ranks other than 0
|
||||
if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
|
||||
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
|
||||
@@ -4893,10 +4892,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
else:
|
||||
model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params)
|
||||
|
||||
# force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
|
||||
del state_dict
|
||||
# force memory release if loading multiple shards
|
||||
if has_multiple_shards:
|
||||
gc.collect()
|
||||
|
||||
# Adjust offloaded weights name and save if needed
|
||||
if disk_offload_index is not None and len(disk_offload_index) > 0:
|
||||
@@ -5789,11 +5786,24 @@ def expand_device_map(device_map, param_names):
|
||||
return new_device_map
|
||||
|
||||
|
||||
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, dtype: torch.dtype) -> Dict:
|
||||
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict):
|
||||
"""This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
|
||||
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
|
||||
the model, which is actually the loading speed botteneck.
|
||||
Calling this function allows to cut the model loading time by a very large margin.
|
||||
|
||||
A few facts related to loading speed (taking into account the use of this function):
|
||||
- When loading a model the first time, it is usually slower than the subsequent times, because the OS is very likely
|
||||
to cache the different state dicts (if enough ressources/RAM are available)
|
||||
- Trying to force the OS to cache the files in advance (by e.g. accessing a small portion of them) is really hard,
|
||||
and not a good idea in general as this is low level OS optimizations that depend on ressource usage anyway
|
||||
- As of 18/03/2025, loading a Llama 70B model with TP takes ~1 min without file cache, and ~13s with full file cache.
|
||||
The baseline, i.e. only loading the tensor shards on device and adjusting dtype (i.e. copying them) is ~5s with full cache.
|
||||
These numbers are reported for TP on 4 H100 GPUs.
|
||||
- It is useless to pre-allocate more than the model size in this function (i.e. using an `allocation_factor` > 1) as
|
||||
cudaMalloc is not a bottleneck at all anymore
|
||||
- Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
|
||||
However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
|
||||
"""
|
||||
# Remove disk and cpu devices, and cast to proper torch.device
|
||||
accelerator_device_map = {
|
||||
@@ -5808,31 +5818,26 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
|
||||
else None
|
||||
)
|
||||
|
||||
parameter_count = defaultdict(lambda: 0)
|
||||
allocation_factor = 1
|
||||
if torch.distributed.is_initialized() or len(set(accelerator_device_map.values())) >= 2:
|
||||
allocation_factor = 2
|
||||
|
||||
total_byte_count = defaultdict(lambda: 0)
|
||||
for param_name, device in accelerator_device_map.items():
|
||||
param = model.get_parameter_or_buffer(param_name)
|
||||
param_size = int(math.prod(param.shape) * allocation_factor)
|
||||
# The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
|
||||
param_byte_count = math.prod(param.shape) * dtype_byte_size(param.dtype)
|
||||
|
||||
if tp_plan_regex is not None:
|
||||
generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
|
||||
param_size //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1
|
||||
param_byte_count //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1
|
||||
|
||||
parameter_count[device] += param_size
|
||||
|
||||
dtype = dtype if dtype is not None else torch.float32
|
||||
total_byte_count[device] += param_byte_count
|
||||
|
||||
# This will kick off the caching allocator to avoid having to Malloc afterwards
|
||||
for device, param_count in parameter_count.items():
|
||||
max_memory_device = None
|
||||
for device, byte_count in total_byte_count.items():
|
||||
if device.type == "cuda":
|
||||
max_memory_device = torch.cuda.mem_get_info(device.index)[0]
|
||||
# allocate only if we have enough memory
|
||||
if max_memory_device is None or max_memory_device > param_count * dtype_byte_size(dtype):
|
||||
_ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
|
||||
device_memory = torch.cuda.mem_get_info(device)[0]
|
||||
# Allow up to 95% of max device memory
|
||||
byte_count = min(byte_count, int(0.95 * device_memory))
|
||||
# Allocate memory
|
||||
_ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)
|
||||
|
||||
|
||||
def get_disk_only_shard_files(device_map, weight_map):
|
||||
|
||||
Reference in New Issue
Block a user