Load models much faster on accelerator devices!! (#36380)
* caching allocator warmup * Update modeling_utils.py * reuse expanded map * style
This commit is contained in:
@@ -21,11 +21,13 @@ import importlib.metadata
|
||||
import inspect
|
||||
import itertools
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
@@ -4816,8 +4818,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
|
||||
else:
|
||||
folder = None
|
||||
|
||||
if device_map is not None:
|
||||
expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
|
||||
caching_allocator_warmup(model, expanded_device_map, dtype)
|
||||
|
||||
if device_map is not None and is_safetensors:
|
||||
param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
|
||||
param_device_map = expanded_device_map
|
||||
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
|
||||
if sharded_metadata is None:
|
||||
archive_file = (
|
||||
@@ -5795,6 +5802,30 @@ def expand_device_map(device_map, param_names, start_prefix):
|
||||
return new_device_map
|
||||
|
||||
|
||||
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, dtype: torch.dtype) -> Dict:
|
||||
"""This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
|
||||
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
|
||||
the model, which is actually the loading speed botteneck.
|
||||
Calling this function allows to cut the model loading time by a very large margin.
|
||||
"""
|
||||
# Remove disk and cpu devices, and cast to proper torch.device
|
||||
accelerator_device_map = {
|
||||
param: torch.device(device) for param, device in expanded_device_map.items() if device not in ["cpu", "disk"]
|
||||
}
|
||||
parameter_count = defaultdict(lambda: 0)
|
||||
for param_name, device in accelerator_device_map.items():
|
||||
try:
|
||||
param = model.get_parameter(param_name)
|
||||
except AttributeError:
|
||||
param = model.get_buffer(param_name)
|
||||
parameter_count[device] += math.prod(param.shape)
|
||||
|
||||
dtype = dtype if dtype is not None else torch.float32
|
||||
# This will kick off the caching allocator to avoid having to Malloc afterwards
|
||||
for device, param_count in parameter_count.items():
|
||||
_ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
|
||||
|
||||
|
||||
def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
|
||||
"""
|
||||
Returns the list of shard files containing only weights offloaded to disk.
|
||||
|
||||
Reference in New Issue
Block a user