Load models much faster on accelerator devices!! (#36380)

* caching allocator warmup

* Update modeling_utils.py

* reuse expanded map

* style
This commit is contained in:
Cyril Vallez
2025-02-25 09:41:22 +01:00
committed by GitHub
parent 931e5f4ac3
commit 4b5cf5496d

View File

@@ -21,11 +21,13 @@ import importlib.metadata
import inspect import inspect
import itertools import itertools
import json import json
import math
import os import os
import re import re
import shutil import shutil
import tempfile import tempfile
import warnings import warnings
from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
@@ -4816,8 +4818,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1]) folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
else: else:
folder = None folder = None
if device_map is not None:
expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
caching_allocator_warmup(model, expanded_device_map, dtype)
if device_map is not None and is_safetensors: if device_map is not None and is_safetensors:
param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix) param_device_map = expanded_device_map
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32" str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
if sharded_metadata is None: if sharded_metadata is None:
archive_file = ( archive_file = (
@@ -5795,6 +5802,30 @@ def expand_device_map(device_map, param_names, start_prefix):
return new_device_map return new_device_map
def caching_allocator_warmup(model: "PreTrainedModel", expanded_device_map: Dict, dtype: torch.dtype) -> None:
    """Warm up the caching allocator with one large allocation per accelerator device.

    The element counts of all parameters/buffers mapped to a given accelerator device are summed up, and a single
    empty tensor of that total size is allocated on the device. This allows to have one large call to Malloc,
    instead of repeatedly calling it later when loading the model, which is actually the loading speed bottleneck.
    Calling this function allows to cut the model loading time by a very large margin.

    Args:
        model: Model whose parameters and buffers are looked up by name in order to read their shapes.
        expanded_device_map: Mapping from parameter/buffer name to its target device. Entries targeting
            `"cpu"` or `"disk"` (or a cpu `torch.device`) are ignored.
        dtype: dtype of the warmup tensors. `torch.float32` is used when `None`.

    Returns:
        None. The function is only called for its allocation side effect; the tensors are not kept.

    Raises:
        AttributeError: If a name in `expanded_device_map` is neither a parameter nor a buffer of `model`.
    """

    def _targets_accelerator(device) -> bool:
        # A `torch.device` never compares equal to a plain string, so it has to be filtered through its type;
        # otherwise a `torch.device("cpu")` entry would slip through and trigger a large CPU allocation.
        if isinstance(device, torch.device):
            return device.type != "cpu"
        return device not in ("cpu", "disk")

    # Remove disk and cpu devices, and cast to proper torch.device
    accelerator_device_map = {
        param: torch.device(device) for param, device in expanded_device_map.items() if _targets_accelerator(device)
    }

    # Total number of elements that will reside on each accelerator device
    element_count = defaultdict(int)
    for param_name, device in accelerator_device_map.items():
        try:
            param = model.get_parameter(param_name)
        except AttributeError:
            # The name is not a parameter, so it is looked up as a buffer instead
            param = model.get_buffer(param_name)
        element_count[device] += param.numel()

    dtype = dtype if dtype is not None else torch.float32

    # This will kick off the caching allocator to avoid having to Malloc afterwards
    for device, count in element_count.items():
        _ = torch.empty(count, dtype=dtype, device=device, requires_grad=False)
def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
""" """
Returns the list of shard files containing only weights offloaded to disk. Returns the list of shard files containing only weights offloaded to disk.