Enhance Model Loading By Providing Parallelism, Uses Optional Env Flag (#36835)

* Get parallel loader working. Include tests.

* Update the tests for parallel loading

* Rename env variables.

* Add docs for parallel model weight loading.

* Touch up parallel model loading docs.

* Touch up parallel model loading docs again.

* Edit comment in test_modeling_utils_parallel_loading.py

* Make sure HF_PARALLEL_LOADING_WORKERS is spelled correctly in modeling_utils.py

* Correct times for parallelized loading, previous times were for a "hot" filesystem

* Update parallel model loading so the spawn method is encapsulated. DRY up the code by leveraging get_submodule.

* Update docs on model loading parallelism so that details on setting the multiprocessing start method are removed, now that the package handles this step internally.

* Fix style on model loading parallelism changes.

* Merge latest version of master's modeling_utils.

* Removed unused variable.

* Fix argument packing for the parallel loader.

* Fix state dict being undefined in the parallel model loader.

* Rename variables used in parallel model loading for clarity. Use get_module_from_name().

* Switch to the use of threads for parallel model loading.

* Update docs for parallel loading.

* Remove the use of json.loads when evaluating HF_ENABLE_PARALLEL_LOADING. Prefer simple casting.

* Move parallelized shard loading into its own function.

* Remove use of is_true(). Favor checking env var true values for HF_ENABLE_PARALLEL_LOADING.

* Update copyright to 2025 in readme for parallel model loading.

* Remove garbage collection line in load_shard_file, implicit garbage collection already occurs.

* Run formatter on modeling_utils.py

* Apply style fixes

* Delete tests/utils/test_modeling_utils_parallel_loading.py

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
Aaron V
2025-05-23 12:39:47 -04:00
committed by GitHub
parent 1ed19360b1
commit d5f992f5e6
4 changed files with 234 additions and 76 deletions

View File

@@ -27,6 +27,7 @@ import shutil
import tempfile
import warnings
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
@@ -870,6 +871,116 @@ def _load_state_dict_into_meta_model(
return disk_offload_index, cpu_offload_index
def _pick_shard_map_location(shard_file, is_hqq_or_bnb, is_quantized, device_map, hf_quantizer):
    """Choose the `map_location` passed to `load_state_dict` for one shard.

    Returns `"meta"` for `.safetensors` shards when neither hqq/bnb nor non-quantized DeepSpeed ZeRO-3 is in
    play, a `torch.device` built from the first `device_map` value that is neither `"cpu"` nor `"disk"` for
    torchao int4-weight-only / autoquant quantization, and `"cpu"` in every other case.
    """
    if (
        shard_file.endswith(".safetensors")
        and not is_hqq_or_bnb
        and not (is_deepspeed_zero3_enabled() and not is_quantized)
    ):
        return "meta"

    if device_map is None or hf_quantizer is None:
        return "cpu"

    quantization_config = hf_quantizer.quantization_config
    if quantization_config.quant_method == QuantizationMethod.TORCHAO:
        # `quant_type` is only read once we know this is a torchao config
        quant_type = quantization_config.quant_type
        if quant_type in ["int4_weight_only", "autoquant"] or isinstance(quant_type, Int4WeightOnlyConfig):
            non_offloaded_devices = [d for d in device_map.values() if d not in ["cpu", "disk"]]
            return torch.device(non_offloaded_devices[0])

    return "cpu"


def load_shard_file(args):
    """Load a single checkpoint shard into `model_to_load`.

    Args:
        args (`tuple`):
            A 20-item tuple, unpacked positionally in the order shown at the top of the function body. Taking
            a single tuple lets the same callable be used both in a plain loop and with `executor.submit`.

    Returns:
        `tuple`: `(error_msgs, disk_offload_index, cpu_offload_index)`. `error_msgs` is a list (only ever
        filled on the DeepSpeed ZeRO-3 path); the two indexes are the ones received in `args`, replaced by
        the values returned from `_load_state_dict_into_meta_model` when that path runs.
    """
    (
        shard_file,
        state_dict,
        disk_only_shard_files,
        is_hqq_or_bnb,
        is_quantized,
        device_map,
        hf_quantizer,
        key_renaming_mapping,
        weights_only,
        model_to_load,
        expected_keys,
        reverse_key_renaming_mapping,
        disk_offload_folder,
        disk_offload_index,
        cpu_offload_folder,
        cpu_offload_index,
        is_offloaded_safetensors,
        keep_in_fp32_regex,
        unexpected_keys,
        device_mesh,
    ) = args

    # Shards holding only disk-offloaded weights are not loaded at all
    if shard_file in disk_only_shard_files:
        return [], disk_offload_index, cpu_offload_index

    map_location = _pick_shard_map_location(shard_file, is_hqq_or_bnb, is_quantized, device_map, hf_quantizer)

    # An empty `shard_file` means the caller supplied `state_dict` directly, so there is nothing to read
    if shard_file != "":
        state_dict = load_state_dict(
            shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
        )

    # Rename the checkpoint keys; entries without a mapping are dropped
    renamed_state_dict = {}
    for original_key, tensor in state_dict.items():
        if original_key in key_renaming_mapping:
            renamed_state_dict[key_renaming_mapping[original_key]] = tensor
    state_dict = renamed_state_dict

    error_msgs = []
    if is_deepspeed_zero3_enabled() and not is_quantized:
        error_msgs.extend(_load_state_dict_into_zero3_model(model_to_load, state_dict))
        return error_msgs, disk_offload_index, cpu_offload_index

    # With FSDP, non-quantized weights are only loaded on local rank 0
    skip_on_this_rank = is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized
    if not skip_on_this_rank:
        disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
            model_to_load,
            state_dict,
            shard_file,
            expected_keys,
            reverse_key_renaming_mapping,
            device_map=device_map,
            disk_offload_folder=disk_offload_folder,
            disk_offload_index=disk_offload_index,
            cpu_offload_folder=cpu_offload_folder,
            cpu_offload_index=cpu_offload_index,
            hf_quantizer=hf_quantizer,
            is_safetensors=is_offloaded_safetensors,
            keep_in_fp32_regex=keep_in_fp32_regex,
            unexpected_keys=unexpected_keys,
            device_mesh=device_mesh,
        )

    return error_msgs, disk_offload_index, cpu_offload_index
def load_shard_files_with_threadpool(args_list):
    """Load several checkpoint shards concurrently using a thread pool.

    The number of worker threads is read from the `HF_PARALLEL_LOADING_WORKERS` environment variable
    (default `8`), then limited to the number of shards and to a minimum of one.

    Args:
        args_list (`list`):
            One argument tuple per shard, each in the form expected by `load_shard_file`.

    Returns:
        `tuple`: `(error_msgs, disk_offload_index, cpu_offload_index)`. `error_msgs` concatenates the
        messages of every shard in completion order. The two indexes are those returned by the shard that
        finished last, or `None` when `args_list` is empty.

    Raises:
        ValueError: If `HF_PARALLEL_LOADING_WORKERS` is set to something that is not an integer.
    """
    # Nothing to load: a pool cannot be created with zero workers, and there is no shard result to take the
    # offload indexes from.
    if not args_list:
        return [], None, None

    requested_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))

    # Do not spawn any more workers than needed, but always keep at least one: `ThreadPoolExecutor` raises
    # `ValueError` for `max_workers <= 0`.
    num_workers = max(1, min(len(args_list), requested_workers))

    logger.info(f"Loading model weights in parallel with {num_workers} workers...")

    error_msgs = []
    disk_offload_index = None
    cpu_offload_index = None

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
            futures = [executor.submit(load_shard_file, arg) for arg in args_list]
            for future in as_completed(futures):
                # `future.result()` re-raises any exception raised inside the worker
                _error_msgs, disk_offload_index, cpu_offload_index = future.result()
                error_msgs += _error_msgs
                pbar.update(1)

    return error_msgs, disk_offload_index, cpu_offload_index
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None:
path, name = weights_name.rsplit(".", 1)
@@ -4973,9 +5084,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
cpu_offload_folder = tempfile.mkdtemp()
cpu_offload_index = {}
# For nice tqdm bars
if checkpoint_files is not None and len(checkpoint_files) > 1:
checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards")
# To be able to iterate, even if we don't use it if the state_dict is already provided
elif state_dict is not None:
checkpoint_files = [""]
@@ -4993,64 +5101,48 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
expanded_device_map = expand_device_map(device_map, expected_keys)
caching_allocator_warmup(model_to_load, expanded_device_map, hf_quantizer)
# Prepare and compatabilize arguments for serial and parallel shard loading
args_list = [
(
shard_file,
state_dict,
disk_only_shard_files,
is_hqq_or_bnb,
is_quantized,
device_map,
hf_quantizer,
key_renaming_mapping,
weights_only,
model_to_load,
expected_keys,
reverse_key_renaming_mapping,
disk_offload_folder,
disk_offload_index,
cpu_offload_folder,
cpu_offload_index,
is_offloaded_safetensors,
keep_in_fp32_regex,
unexpected_keys,
device_mesh,
)
for shard_file in checkpoint_files
]
error_msgs = []
# Iterate on all the shards to load the weights
for shard_file in checkpoint_files:
# Skip the load for shards that only contain disk-offloaded weights
if shard_file in disk_only_shard_files:
continue
map_location = "cpu"
if (
shard_file.endswith(".safetensors")
and not is_hqq_or_bnb
and not (is_deepspeed_zero3_enabled() and not is_quantized)
):
map_location = "meta"
elif (
device_map is not None
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
and (
hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
)
):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
if (
os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
and not is_deepspeed_zero3_enabled()
):
_error_msgs, disk_offload_index, cpu_offload_index = load_shard_files_with_threadpool(args_list)
error_msgs += _error_msgs
else:
if len(args_list) > 1:
args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")
# If shard_file is "", we use the existing state_dict instead of loading it
if shard_file != "":
state_dict = load_state_dict(
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
)
# Fix the key names
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
if is_deepspeed_zero3_enabled() and not is_quantized:
error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
# Skip it with fsdp on ranks other than 0
elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
shard_file,
expected_keys,
reverse_key_renaming_mapping,
device_map=device_map,
disk_offload_folder=disk_offload_folder,
disk_offload_index=disk_offload_index,
cpu_offload_folder=cpu_offload_folder,
cpu_offload_index=cpu_offload_index,
hf_quantizer=hf_quantizer,
is_safetensors=is_offloaded_safetensors,
keep_in_fp32_regex=keep_in_fp32_regex,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
)
# force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
del state_dict
for args in args_list:
_error_msgs, disk_offload_index, cpu_offload_index = load_shard_file(args)
error_msgs += _error_msgs
# Adjust offloaded weights name and save if needed
if disk_offload_index is not None and len(disk_offload_index) > 0: