Enhance Model Loading By Providing Parallelism, Uses Optional Env Flag (#36835)
* Get parallel loader working. Include tests. * Update the tests for parallel loading * Rename env variables. * Add docs for parallel model weight loading. * Touch up parallel model loading docs. * Touch up parallel model loading docs again. * Edit comment in test_modeling_utils_parallel_loading.py * Make sure HF_PARALLEL_LOADING_WORKERS is spelled correctly in modeling_utils.py * Correct times for parallelized loading, previous times were for a "hot" filesystem * Update parallel model loading so the spawn method is encapsulated. DRY up the code by leveraging get_submodule. * Update docs on model loading parallelism so that details on setting the multiprocessing start method are removed, now that the package handles this step internally. * Fix style on model loading parallelism changes. * Merge latest version of master's modeling_utils. * Removed unused variable. * Fix argument packing for the parallel loader. * Fix state dict being undefined in the parallel model loader. * Rename variables used in parallel model loading for clarity. Use get_module_from_name(). * Switch to the use of threads for parallel model loading. * Update docs for parallel loading. * Remove the use of json.loads when evaluating HF_ENABLE_PARALLEL_LOADING. Prefer simple casting. * Move parallelized shard loading into its own function. * Remove use of is_true(). Favor checking env var true values for HF_ENABLE_PARALLEL_LOADING. * Update copyright to 2025 in readme for paralell model loading. * Remove garbage collection line in load_shard_file, implicit garbage collection already occurs. * Run formatter on modeling_utils.py * Apply style fixes * Delete tests/utils/test_modeling_utils_parallel_loading.py --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
@@ -27,6 +27,7 @@ import shutil
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
@@ -870,6 +871,116 @@ def _load_state_dict_into_meta_model(
|
||||
return disk_offload_index, cpu_offload_index
|
||||
|
||||
|
||||
def load_shard_file(args):
|
||||
(
|
||||
shard_file,
|
||||
state_dict,
|
||||
disk_only_shard_files,
|
||||
is_hqq_or_bnb,
|
||||
is_quantized,
|
||||
device_map,
|
||||
hf_quantizer,
|
||||
key_renaming_mapping,
|
||||
weights_only,
|
||||
model_to_load,
|
||||
expected_keys,
|
||||
reverse_key_renaming_mapping,
|
||||
disk_offload_folder,
|
||||
disk_offload_index,
|
||||
cpu_offload_folder,
|
||||
cpu_offload_index,
|
||||
is_offloaded_safetensors,
|
||||
keep_in_fp32_regex,
|
||||
unexpected_keys,
|
||||
device_mesh,
|
||||
) = args
|
||||
|
||||
# Skip the load for shards that only contain disk-offloaded weights
|
||||
if shard_file in disk_only_shard_files:
|
||||
return [], disk_offload_index, cpu_offload_index
|
||||
|
||||
map_location = "cpu"
|
||||
if (
|
||||
shard_file.endswith(".safetensors")
|
||||
and not is_hqq_or_bnb
|
||||
and not (is_deepspeed_zero3_enabled() and not is_quantized)
|
||||
):
|
||||
map_location = "meta"
|
||||
elif (
|
||||
device_map is not None
|
||||
and hf_quantizer is not None
|
||||
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
|
||||
and (
|
||||
hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
|
||||
or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
|
||||
)
|
||||
):
|
||||
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
|
||||
|
||||
# If shard_file is "", we use the existing state_dict instead of loading it
|
||||
if shard_file != "":
|
||||
state_dict = load_state_dict(
|
||||
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
|
||||
)
|
||||
|
||||
# Fix the key names
|
||||
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
|
||||
|
||||
error_msgs = []
|
||||
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
|
||||
# Skip it with fsdp on ranks other than 0
|
||||
elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
|
||||
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
|
||||
model_to_load,
|
||||
state_dict,
|
||||
shard_file,
|
||||
expected_keys,
|
||||
reverse_key_renaming_mapping,
|
||||
device_map=device_map,
|
||||
disk_offload_folder=disk_offload_folder,
|
||||
disk_offload_index=disk_offload_index,
|
||||
cpu_offload_folder=cpu_offload_folder,
|
||||
cpu_offload_index=cpu_offload_index,
|
||||
hf_quantizer=hf_quantizer,
|
||||
is_safetensors=is_offloaded_safetensors,
|
||||
keep_in_fp32_regex=keep_in_fp32_regex,
|
||||
unexpected_keys=unexpected_keys,
|
||||
device_mesh=device_mesh,
|
||||
)
|
||||
|
||||
return error_msgs, disk_offload_index, cpu_offload_index
|
||||
|
||||
|
||||
def load_shard_files_with_threadpool(args_list):
|
||||
num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
|
||||
|
||||
# Do not spawn anymore workers than you need
|
||||
num_workers = min(len(args_list), num_workers)
|
||||
|
||||
logger.info(f"Loading model weights in parallel with {num_workers} workers...")
|
||||
|
||||
error_msgs = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
|
||||
futures = [executor.submit(load_shard_file, arg) for arg in args_list]
|
||||
for future in as_completed(futures):
|
||||
result = future.result()
|
||||
(
|
||||
_error_msgs,
|
||||
disk_offload_index,
|
||||
cpu_offload_index,
|
||||
) = result
|
||||
|
||||
error_msgs += _error_msgs
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
return error_msgs, disk_offload_index, cpu_offload_index
|
||||
|
||||
|
||||
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
||||
if variant is not None:
|
||||
path, name = weights_name.rsplit(".", 1)
|
||||
@@ -4973,9 +5084,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
cpu_offload_folder = tempfile.mkdtemp()
|
||||
cpu_offload_index = {}
|
||||
|
||||
# For nice tqdm bars
|
||||
if checkpoint_files is not None and len(checkpoint_files) > 1:
|
||||
checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards")
|
||||
# To be able to iterate, even if we don't use it if the state_dict is already provided
|
||||
elif state_dict is not None:
|
||||
checkpoint_files = [""]
|
||||
@@ -4993,64 +5101,48 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
expanded_device_map = expand_device_map(device_map, expected_keys)
|
||||
caching_allocator_warmup(model_to_load, expanded_device_map, hf_quantizer)
|
||||
|
||||
# Prepare and compatabilize arguments for serial and parallel shard loading
|
||||
args_list = [
|
||||
(
|
||||
shard_file,
|
||||
state_dict,
|
||||
disk_only_shard_files,
|
||||
is_hqq_or_bnb,
|
||||
is_quantized,
|
||||
device_map,
|
||||
hf_quantizer,
|
||||
key_renaming_mapping,
|
||||
weights_only,
|
||||
model_to_load,
|
||||
expected_keys,
|
||||
reverse_key_renaming_mapping,
|
||||
disk_offload_folder,
|
||||
disk_offload_index,
|
||||
cpu_offload_folder,
|
||||
cpu_offload_index,
|
||||
is_offloaded_safetensors,
|
||||
keep_in_fp32_regex,
|
||||
unexpected_keys,
|
||||
device_mesh,
|
||||
)
|
||||
for shard_file in checkpoint_files
|
||||
]
|
||||
|
||||
error_msgs = []
|
||||
# Iterate on all the shards to load the weights
|
||||
for shard_file in checkpoint_files:
|
||||
# Skip the load for shards that only contain disk-offloaded weights
|
||||
if shard_file in disk_only_shard_files:
|
||||
continue
|
||||
|
||||
map_location = "cpu"
|
||||
if (
|
||||
shard_file.endswith(".safetensors")
|
||||
and not is_hqq_or_bnb
|
||||
and not (is_deepspeed_zero3_enabled() and not is_quantized)
|
||||
):
|
||||
map_location = "meta"
|
||||
elif (
|
||||
device_map is not None
|
||||
and hf_quantizer is not None
|
||||
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
|
||||
and (
|
||||
hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
|
||||
or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
|
||||
)
|
||||
):
|
||||
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
|
||||
if (
|
||||
os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
|
||||
and not is_deepspeed_zero3_enabled()
|
||||
):
|
||||
_error_msgs, disk_offload_index, cpu_offload_index = load_shard_files_with_threadpool(args_list)
|
||||
error_msgs += _error_msgs
|
||||
else:
|
||||
if len(args_list) > 1:
|
||||
args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")
|
||||
|
||||
# If shard_file is "", we use the existing state_dict instead of loading it
|
||||
if shard_file != "":
|
||||
state_dict = load_state_dict(
|
||||
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
|
||||
)
|
||||
|
||||
# Fix the key names
|
||||
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
|
||||
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
|
||||
# Skip it with fsdp on ranks other than 0
|
||||
elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
|
||||
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
|
||||
model_to_load,
|
||||
state_dict,
|
||||
shard_file,
|
||||
expected_keys,
|
||||
reverse_key_renaming_mapping,
|
||||
device_map=device_map,
|
||||
disk_offload_folder=disk_offload_folder,
|
||||
disk_offload_index=disk_offload_index,
|
||||
cpu_offload_folder=cpu_offload_folder,
|
||||
cpu_offload_index=cpu_offload_index,
|
||||
hf_quantizer=hf_quantizer,
|
||||
is_safetensors=is_offloaded_safetensors,
|
||||
keep_in_fp32_regex=keep_in_fp32_regex,
|
||||
unexpected_keys=unexpected_keys,
|
||||
device_mesh=device_mesh,
|
||||
)
|
||||
|
||||
# force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
|
||||
del state_dict
|
||||
for args in args_list:
|
||||
_error_msgs, disk_offload_index, cpu_offload_index = load_shard_file(args)
|
||||
error_msgs += _error_msgs
|
||||
|
||||
# Adjust offloaded weights name and save if needed
|
||||
if disk_offload_index is not None and len(disk_offload_index) > 0:
|
||||
|
||||
Reference in New Issue
Block a user