diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3ebd0eacfa..d7abc3bf7e 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -21,11 +21,13 @@ import importlib.metadata import inspect import itertools import json +import math import os import re import shutil import tempfile import warnings +from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass from enum import Enum @@ -4816,8 +4818,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1]) else: folder = None + + if device_map is not None: + expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix) + caching_allocator_warmup(model, expanded_device_map, dtype) + if device_map is not None and is_safetensors: - param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix) + param_device_map = expanded_device_map str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32" if sharded_metadata is None: archive_file = ( @@ -5795,6 +5802,30 @@ def expand_device_map(device_map, param_names, start_prefix): return new_device_map +def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, dtype: torch.dtype) -> Dict: + """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each + device. It allows to have one large call to Malloc, instead of recursively calling it later when loading + the model, which is actually the loading speed botteneck. + Calling this function allows to cut the model loading time by a very large margin. + """ + # Remove disk and cpu devices, and cast to proper torch.device + accelerator_device_map = { + param: torch.device(device) for param, device in expanded_device_map.items() if device not in ["cpu", "disk"] + } + parameter_count = defaultdict(lambda: 0) + for param_name, device in accelerator_device_map.items(): + try: + param = model.get_parameter(param_name) + except AttributeError: + param = model.get_buffer(param_name) + parameter_count[device] += math.prod(param.shape) + + dtype = dtype if dtype is not None else torch.float32 + # This will kick off the caching allocator to avoid having to Malloc afterwards + for device, param_count in parameter_count.items(): + _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False) + + def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): """ Returns the list of shard files containing only weights offloaded to disk.