Load models much faster on accelerator devices!! (#36380)
* caching allocator warmup * Update modeling_utils.py * reuse expanded map * style
This commit is contained in:
@@ -21,11 +21,13 @@ import importlib.metadata
|
|||||||
import inspect
|
import inspect
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections import defaultdict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -4816,8 +4818,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
|
folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
|
||||||
else:
|
else:
|
||||||
folder = None
|
folder = None
|
||||||
|
|
||||||
|
if device_map is not None:
|
||||||
|
expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
|
||||||
|
caching_allocator_warmup(model, expanded_device_map, dtype)
|
||||||
|
|
||||||
if device_map is not None and is_safetensors:
|
if device_map is not None and is_safetensors:
|
||||||
param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
|
param_device_map = expanded_device_map
|
||||||
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
|
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
|
||||||
if sharded_metadata is None:
|
if sharded_metadata is None:
|
||||||
archive_file = (
|
archive_file = (
|
||||||
@@ -5795,6 +5802,30 @@ def expand_device_map(device_map, param_names, start_prefix):
|
|||||||
return new_device_map
|
return new_device_map
|
||||||
|
|
||||||
|
|
||||||
|
def caching_allocator_warmup(model: "PreTrainedModel", expanded_device_map: Dict, dtype: torch.dtype) -> None:
    """Warm up the device caching allocator before the model weights are loaded.

    For every accelerator device referenced in `expanded_device_map`, the total number of elements of the
    parameters/buffers mapped to that device is summed and a single empty tensor of that size is allocated.
    This results in one large allocation per device instead of many small ones while the weights are being
    loaded, which is the actual loading speed bottleneck.

    Args:
        model: The model whose parameters and buffers are about to be loaded. Only the shapes of its tensors
            are read.
        expanded_device_map: Mapping from parameter/buffer name to the device it will reside on. Entries
            mapped to `"cpu"` or `"disk"` (given as strings or as `torch.device` objects) are ignored.
        dtype: The dtype the weights will be loaded in. `None` falls back to `torch.float32`.

    Returns:
        None. The function is called only for its side effect on the allocator.
    """
    # Remove disk and cpu devices, and cast to proper torch.device. Comparing on `str(device)` makes the
    # filter work for both plain strings and `torch.device` instances (a `torch.device("cpu")` never
    # compares equal to the string "cpu").
    accelerator_device_map = {
        param_name: torch.device(device)
        for param_name, device in expanded_device_map.items()
        if str(device) not in ("cpu", "disk")
    }
    if not accelerator_device_map:
        # Nothing will live on an accelerator, so there is nothing to warm up.
        return

    element_count = defaultdict(int)
    for param_name, device in accelerator_device_map.items():
        try:
            param = model.get_parameter(param_name)
        except AttributeError:
            # The name does not refer to a parameter; look it up as a buffer instead.
            param = model.get_buffer(param_name)
        element_count[device] += math.prod(param.shape)

    dtype = dtype if dtype is not None else torch.float32
    # This will kick off the caching allocator to avoid having to Malloc afterwards. The tensor is
    # intentionally not kept: only the allocation itself matters.
    for device, count in element_count.items():
        _ = torch.empty(count, dtype=dtype, device=device, requires_grad=False)
|
||||||
|
|
||||||
|
|
||||||
def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
|
def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
|
||||||
"""
|
"""
|
||||||
Returns the list of shard files containing only weights offloaded to disk.
|
Returns the list of shard files containing only weights offloaded to disk.
|
||||||
|
|||||||
Reference in New Issue
Block a user