Loading optimizations (#36742)
* improvements * Update modeling_utils.py * add some doc about loading * Update modeling_utils.py
This commit is contained in:
@@ -4824,11 +4824,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# Warmup cuda to load the weights much faster on devices
|
# Warmup cuda to load the weights much faster on devices
|
||||||
if device_map is not None and hf_quantizer is None:
|
if device_map is not None and hf_quantizer is None:
|
||||||
expanded_device_map = expand_device_map(device_map, expected_keys)
|
expanded_device_map = expand_device_map(device_map, expected_keys)
|
||||||
caching_allocator_warmup(model_to_load, expanded_device_map, dtype)
|
caching_allocator_warmup(model_to_load, expanded_device_map)
|
||||||
|
|
||||||
error_msgs = []
|
error_msgs = []
|
||||||
mismatched_keys = []
|
mismatched_keys = []
|
||||||
has_multiple_shards = len(checkpoint_files) > 1
|
|
||||||
# Iterate on all the shards to load the weights
|
# Iterate on all the shards to load the weights
|
||||||
for shard_file in checkpoint_files:
|
for shard_file in checkpoint_files:
|
||||||
# Skip the load for shards that only contain disk-offloaded weights
|
# Skip the load for shards that only contain disk-offloaded weights
|
||||||
@@ -4865,7 +4864,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
prefix if loading_base_model_from_task_state_dict else "",
|
prefix if loading_base_model_from_task_state_dict else "",
|
||||||
)
|
)
|
||||||
|
|
||||||
if low_cpu_mem_usage and shard_file is not None:
|
if low_cpu_mem_usage:
|
||||||
# Skip it with fsdp on ranks other than 0
|
# Skip it with fsdp on ranks other than 0
|
||||||
if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
|
if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
|
||||||
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
|
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
|
||||||
@@ -4893,10 +4892,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
else:
|
else:
|
||||||
model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params)
|
model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params)
|
||||||
|
|
||||||
|
# force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
|
||||||
del state_dict
|
del state_dict
|
||||||
# force memory release if loading multiple shards
|
|
||||||
if has_multiple_shards:
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
# Adjust offloaded weights name and save if needed
|
# Adjust offloaded weights name and save if needed
|
||||||
if disk_offload_index is not None and len(disk_offload_index) > 0:
|
if disk_offload_index is not None and len(disk_offload_index) > 0:
|
||||||
@@ -5789,11 +5786,24 @@ def expand_device_map(device_map, param_names):
|
|||||||
return new_device_map
|
return new_device_map
|
||||||
|
|
||||||
|
|
||||||
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, dtype: torch.dtype) -> Dict:
|
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict):
|
||||||
"""This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
|
"""This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
|
||||||
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
|
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
|
||||||
the model, which is actually the loading speed bottleneck.
|
the model, which is actually the loading speed bottleneck.
|
||||||
Calling this function allows to cut the model loading time by a very large margin.
|
Calling this function allows to cut the model loading time by a very large margin.
|
||||||
|
|
||||||
|
A few facts related to loading speed (taking into account the use of this function):
|
||||||
|
- When loading a model the first time, it is usually slower than the subsequent times, because the OS is very likely
|
||||||
|
to cache the different state dicts (if enough resources/RAM are available)
|
||||||
|
- Trying to force the OS to cache the files in advance (by e.g. accessing a small portion of them) is really hard,
|
||||||
|
and not a good idea in general as these are low-level OS optimizations that depend on resource usage anyway
|
||||||
|
- As of 18/03/2025, loading a Llama 70B model with TP takes ~1 min without file cache, and ~13s with full file cache.
|
||||||
|
The baseline, i.e. only loading the tensor shards on device and adjusting dtype (i.e. copying them) is ~5s with full cache.
|
||||||
|
These numbers are reported for TP on 4 H100 GPUs.
|
||||||
|
- It is useless to pre-allocate more than the model size in this function (i.e. using an `allocation_factor` > 1) as
|
||||||
|
cudaMalloc is not a bottleneck at all anymore
|
||||||
|
- Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
|
||||||
|
However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
|
||||||
"""
|
"""
|
||||||
# Remove disk and cpu devices, and cast to proper torch.device
|
# Remove disk and cpu devices, and cast to proper torch.device
|
||||||
accelerator_device_map = {
|
accelerator_device_map = {
|
||||||
@@ -5808,31 +5818,26 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
parameter_count = defaultdict(lambda: 0)
|
total_byte_count = defaultdict(lambda: 0)
|
||||||
allocation_factor = 1
|
|
||||||
if torch.distributed.is_initialized() or len(set(accelerator_device_map.values())) >= 2:
|
|
||||||
allocation_factor = 2
|
|
||||||
|
|
||||||
for param_name, device in accelerator_device_map.items():
|
for param_name, device in accelerator_device_map.items():
|
||||||
param = model.get_parameter_or_buffer(param_name)
|
param = model.get_parameter_or_buffer(param_name)
|
||||||
param_size = int(math.prod(param.shape) * allocation_factor)
|
# The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
|
||||||
|
param_byte_count = math.prod(param.shape) * dtype_byte_size(param.dtype)
|
||||||
|
|
||||||
if tp_plan_regex is not None:
|
if tp_plan_regex is not None:
|
||||||
generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
|
generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
|
||||||
param_size //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1
|
param_byte_count //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1
|
||||||
|
|
||||||
parameter_count[device] += param_size
|
total_byte_count[device] += param_byte_count
|
||||||
|
|
||||||
dtype = dtype if dtype is not None else torch.float32
|
|
||||||
|
|
||||||
# This will kick off the caching allocator to avoid having to Malloc afterwards
|
# This will kick off the caching allocator to avoid having to Malloc afterwards
|
||||||
for device, param_count in parameter_count.items():
|
for device, byte_count in total_byte_count.items():
|
||||||
max_memory_device = None
|
|
||||||
if device.type == "cuda":
|
if device.type == "cuda":
|
||||||
max_memory_device = torch.cuda.mem_get_info(device.index)[0]
|
device_memory = torch.cuda.mem_get_info(device)[0]
|
||||||
# allocate only if we have enough memory
|
# Allow up to 95% of max device memory
|
||||||
if max_memory_device is None or max_memory_device > param_count * dtype_byte_size(dtype):
|
byte_count = min(byte_count, int(0.95 * device_memory))
|
||||||
_ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
|
# Allocate memory
|
||||||
|
_ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)
|
||||||
|
|
||||||
|
|
||||||
def get_disk_only_shard_files(device_map, weight_map):
|
def get_disk_only_shard_files(device_map, weight_map):
|
||||||
|
|||||||
Reference in New Issue
Block a user