Remove low_cpu_mem_usage and _fast_init (#36963)
* Remove low_cpu_mem_usage and _fast_init * Update deepspeed.py * Update modeling_utils.py * remove the first 2 tests everywhere * Update test_modeling_common.py * remove what was remaining about fast_init * fix logic and simplify * mismatched keys logic update * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * fix 2 models init_weights * extend to others * remove grad * Update modeling_fsmt.py * init weights in tests * style * Update test_modeling_fsmt.py * more old models * fix more init_weights * copies * fix * style * Update modeling_lxmert.py * fix inits * more and more * more * should finalize * style * Update modeling_dinov2_with_registers.py * fix * Update modeling_encoder_decoder.py * fix * style * Update modeling_lxmert.py * post rebase cleanup * Update modeling_informer.py * back to start for device * fix * add test to detect all failing cases correctly * Update test_modeling_common.py * fix * fix * sam * style * Update modeling_maskformer_swin.py * CIs * CIs * remove test - will add it on separate PR * fix * fix * Update modeling_sam.py * CIs * CIs * CIs * convnext * suggestions * CIs * fix copies after merge --------- Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -73,7 +73,6 @@ from .quantizers import AutoHfQuantizer, HfQuantizer
|
||||
from .quantizers.quantizers_utils import get_module_from_name
|
||||
from .safetensors_conversion import auto_conversion
|
||||
from .utils import (
|
||||
ACCELERATE_MIN_VERSION,
|
||||
ADAPTER_SAFE_WEIGHTS_NAME,
|
||||
ADAPTER_WEIGHTS_NAME,
|
||||
CONFIG_NAME,
|
||||
@@ -137,7 +136,6 @@ if is_accelerate_available():
|
||||
load_offloaded_weights,
|
||||
offload_weight,
|
||||
save_offload_index,
|
||||
set_module_tensor_to_device,
|
||||
)
|
||||
|
||||
accelerate_version = version.parse(importlib.metadata.version("accelerate"))
|
||||
@@ -208,32 +206,29 @@ TORCH_INIT_FUNCTIONS = {
|
||||
|
||||
|
||||
@contextmanager
|
||||
def no_init_weights(_enable=True):
|
||||
def no_init_weights():
|
||||
"""
|
||||
Context manager to globally disable weight initialization to speed up loading large models.
|
||||
|
||||
TODO(Patrick): Delete safety argument `_enable=True` at next major version.
|
||||
"""
|
||||
global _init_weights
|
||||
old_init_weights = _init_weights
|
||||
|
||||
if _enable:
|
||||
_init_weights = False
|
||||
_init_weights = False
|
||||
|
||||
def _skip_init(*args, **kwargs):
|
||||
pass
|
||||
def _skip_init(*args, **kwargs):
|
||||
pass
|
||||
|
||||
# Temporarily replace the torch init functions with no-ops (the originals remain available in TORCH_INIT_FUNCTIONS)
|
||||
for name, init_func in TORCH_INIT_FUNCTIONS.items():
|
||||
setattr(torch.nn.init, name, _skip_init)
|
||||
|
||||
# # Save the original initialization functions
|
||||
for name, init_func in TORCH_INIT_FUNCTIONS.items():
|
||||
setattr(torch.nn.init, name, _skip_init)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_init_weights = old_init_weights
|
||||
if _enable:
|
||||
# # Restore the original initialization functions
|
||||
for name, init_func in TORCH_INIT_FUNCTIONS.items():
|
||||
setattr(torch.nn.init, name, init_func)
|
||||
# Restore the original initialization functions
|
||||
for name, init_func in TORCH_INIT_FUNCTIONS.items():
|
||||
setattr(torch.nn.init, name, init_func)
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -404,37 +399,6 @@ def dtype_byte_size(dtype):
|
||||
return bit_size // 8
|
||||
|
||||
|
||||
def check_support_param_buffer_assignment(model_to_load, state_dict):
|
||||
"""
|
||||
Checks if `model_to_load` supports param buffer assignment (such
|
||||
as when loading in empty weights) by first checking
|
||||
if the model explicitly disables it, then by ensuring that the state dict keys
|
||||
are a subset of the model's parameters.
|
||||
|
||||
Note: We fully disable this if we are using `deepspeed`
|
||||
"""
|
||||
if len(state_dict) == 0:
|
||||
return False
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
return False
|
||||
|
||||
# Some models explicitly do not support param buffer assignment
|
||||
if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
|
||||
logger.debug(
|
||||
f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
|
||||
)
|
||||
return False
|
||||
|
||||
# If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
|
||||
first_key = next(iter(model_to_load.state_dict().keys()))
|
||||
if first_key in state_dict:
|
||||
return state_dict[first_key].dtype == model_to_load.state_dict()[first_key].dtype
|
||||
|
||||
# For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`)
|
||||
return False
|
||||
|
||||
|
||||
def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
|
||||
"""
|
||||
This is the same as
|
||||
@@ -750,6 +714,13 @@ def _infer_parameter_dtype(
|
||||
return old_param is not None and old_param.is_contiguous(), casting_dtype
|
||||
|
||||
|
||||
def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor: torch.Tensor):
|
||||
"""Cast a single parameter `param_name` into the `model`, with value `tensor`."""
|
||||
module, param_type = get_module_from_name(model, param_name)
|
||||
# This will check potential shape mismatch if skipped before
|
||||
module.load_state_dict({param_type: tensor}, strict=False, assign=True)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _load_state_dict_into_meta_model(
|
||||
model: "PreTrainedModel",
|
||||
@@ -857,17 +828,12 @@ def _load_state_dict_into_meta_model(
|
||||
):
|
||||
if is_fsdp_enabled():
|
||||
param_device = "cpu" if is_local_dist_rank_0() else "meta"
|
||||
module, param_type = get_module_from_name(model, param_name)
|
||||
|
||||
# avoid tied weights
|
||||
if param.data_ptr() in data_ptrs:
|
||||
param = param.clone()
|
||||
|
||||
module.load_state_dict(
|
||||
{param_type: param.to(param_device)},
|
||||
strict=False,
|
||||
assign=True,
|
||||
)
|
||||
_load_parameter_into_model(model, param_name, param.to(param_device))
|
||||
|
||||
# Add `data_ptr` of `model.state_dict()[param_name]` to avoid tied weights
|
||||
data_ptrs.add(model.state_dict()[param_name].data_ptr())
|
||||
@@ -1397,18 +1363,7 @@ def _find_missing_and_unexpected_keys(
|
||||
if has_inv_freq_buffers:
|
||||
unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]
|
||||
|
||||
if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
|
||||
ptrs = collections.defaultdict(list)
|
||||
for name, tensor in model.state_dict().items():
|
||||
id_tensor = id_tensor_storage(tensor)
|
||||
ptrs[id_tensor].append(name)
|
||||
|
||||
# These are all the pointers of shared tensors.
|
||||
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
|
||||
else:
|
||||
# id function doesn't work for meta tensor so we need this function
|
||||
tied_params = find_tied_parameters(model)
|
||||
|
||||
tied_params = find_tied_parameters(model)
|
||||
for group in tied_params:
|
||||
missing_in_group = [k for k in missing_keys if k in group]
|
||||
if len(missing_in_group) > 0 and len(missing_in_group) < len(group):
|
||||
@@ -1430,29 +1385,59 @@ def _find_missing_and_unexpected_keys(
|
||||
|
||||
|
||||
def _find_mismatched_keys(
|
||||
model_to_load: "PreTrainedModel",
|
||||
state_dict: Dict,
|
||||
model: "PreTrainedModel",
|
||||
state_dict: Optional[Dict],
|
||||
checkpoint_files: Optional[List[str]],
|
||||
ignore_mismatched_sizes: bool,
|
||||
prefix: str,
|
||||
) -> List:
|
||||
"""Find mismatch of shapes between the model parameters and the loaded state dict, and optionally remove the
|
||||
problematic keys from `state_dict` if `ignore_mismatched_sizes` is `True`."""
|
||||
keys_to_rename_mapping: Dict[str, str],
|
||||
is_quantized: bool,
|
||||
weights_only: bool,
|
||||
) -> Tuple[List[str], List[Tuple[int, int]]]:
|
||||
"""
|
||||
Find potential shape mismatch between the different state dicts and the model parameters, but only if `ignore_mismatched_sizes`
|
||||
is True. Otherwise, return immediately and any shape mismatch that may exist will be raised later on. This avoids checking
|
||||
every parameter in advance, as shape mismatches are extremely rare in practice. If we want to ignore them however, we do
|
||||
need to check in advance as we need to know which parameters we need to move back from meta to cpu, and initialize
|
||||
correctly. Indeed, as our model initialization takes place at the module level, and not the weight level, in the
|
||||
case of a sharded checkpoint we cannot correctly initialize the weights according to `model._init_weights()` if we perform
|
||||
this check on each state dict at loading time (after the first loaded checkpoint, there is no way to initialize only the
|
||||
mismatched weights if any, without overwriting the previously loaded weights as well because the whole module will be
|
||||
initialized, not only the weights that are mismatched).
|
||||
"""
|
||||
|
||||
# An error will be raised later on anyway if there is a mismatch - this avoids running the rest of this function
|
||||
# if there is no mismatch (which is almost always the case)
|
||||
if not ignore_mismatched_sizes:
|
||||
return [], []
|
||||
|
||||
if state_dict is not None:
|
||||
checkpoint_files = [""]
|
||||
|
||||
model_state_dict = model.state_dict()
|
||||
mismatched_keys = []
|
||||
if ignore_mismatched_sizes:
|
||||
model_state_dict = model_to_load.state_dict()
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
for key in state_dict_keys:
|
||||
if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape:
|
||||
if state_dict[key].shape[-1] == 1 and state_dict[key].numel() * 2 == model_state_dict[key].numel():
|
||||
# This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
|
||||
# Without matching with module type or parameter type it seems like a practical way to detect valid 4bit weights.
|
||||
pass
|
||||
else:
|
||||
# Add prefix if we removed it before, to add the correct state dict key to the warnings
|
||||
key_with_prefix = prefix + key
|
||||
mismatched_keys.append((key_with_prefix, state_dict[key].shape, model_state_dict[key].shape))
|
||||
del state_dict[key]
|
||||
return mismatched_keys
|
||||
mismatched_shapes = []
|
||||
for shard_file in checkpoint_files:
|
||||
# If shard_file is "", we use the existing state_dict instead of loading it
|
||||
if shard_file != "":
|
||||
state_dict = load_state_dict(
|
||||
shard_file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only
|
||||
)
|
||||
|
||||
# Fix the key names
|
||||
new_state_dict = {keys_to_rename_mapping[k]: v for k, v in state_dict.items() if k in keys_to_rename_mapping}
|
||||
|
||||
for key in new_state_dict.keys():
|
||||
if key in model_state_dict and new_state_dict[key].shape != model_state_dict[key].shape:
|
||||
# This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
|
||||
# Without matching with module type or parameter type it seems like a practical way to detect valid 4bit weights.
|
||||
if not (
|
||||
new_state_dict[key].shape[-1] == 1
|
||||
and new_state_dict[key].numel() * 2 == model_state_dict[key].numel()
|
||||
):
|
||||
mismatched_keys.append(key)
|
||||
mismatched_shapes.append((new_state_dict[key].shape, model_state_dict[key].shape))
|
||||
|
||||
return mismatched_keys, mismatched_shapes
|
||||
|
||||
|
||||
class PipelineParallel(Enum):
|
||||
@@ -3773,13 +3758,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
@classmethod
|
||||
def get_init_context(
|
||||
cls: Type[SpecificPreTrainedModelType],
|
||||
_fast_init=True,
|
||||
is_quantized=None,
|
||||
_is_ds_init_called=None,
|
||||
low_cpu_mem_usage=True,
|
||||
):
|
||||
init_contexts = [no_init_weights(_enable=_fast_init)]
|
||||
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
|
||||
import deepspeed
|
||||
|
||||
@@ -3787,13 +3768,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
init_contexts = [
|
||||
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
|
||||
set_zero3_state(),
|
||||
] + init_contexts
|
||||
elif low_cpu_mem_usage:
|
||||
if not is_accelerate_available():
|
||||
raise ImportError(
|
||||
f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
|
||||
)
|
||||
init_contexts.append(init_empty_weights())
|
||||
no_init_weights(),
|
||||
]
|
||||
else:
|
||||
init_contexts = [no_init_weights(), init_empty_weights()]
|
||||
|
||||
if is_deepspeed_zero3_enabled() and is_quantized:
|
||||
init_contexts.append(set_quantized_state())
|
||||
@@ -3829,10 +3807,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
|
||||
weights are discarded.
|
||||
|
||||
If model weights are the same precision as the base model (and is a supported model), weights will be lazily loaded
|
||||
in using the `meta` device and brought into memory once an input is passed through that layer regardless of
|
||||
`low_cpu_mem_usage`.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
@@ -3910,31 +3884,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
|
||||
|
||||
</Tip>
|
||||
_fast_init(`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to disable fast initialization.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
One should only disable *_fast_init* to ensure backwards compatibility with `transformers.__version__ <
|
||||
4.6.0` for seeded model initialization. This argument will be removed at the next major version. See
|
||||
[pull request 11471](https://github.com/huggingface/transformers/pull/11471) for more information.
|
||||
|
||||
</Tip>
|
||||
attn_implementation (`str`, *optional*):
|
||||
The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
|
||||
|
||||
> Parameters for big model inference
|
||||
|
||||
low_cpu_mem_usage(`bool`, *optional*):
|
||||
Tries not to use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Generally should be combined with a `device_map` (such as `"auto"`) for best results.
|
||||
This is an experimental feature and a subject to change at any moment.
|
||||
</Tip>
|
||||
If the model weights are in the same precision as the model loaded in, `low_cpu_mem_usage` (without
|
||||
`device_map`) is redundant and will not provide any benefit in regards to CPU memory usage. However,
|
||||
this should still be enabled if you are passing in a `device_map`.
|
||||
</Tip>
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model under a specific `dtype`. The different options
|
||||
are:
|
||||
@@ -4045,37 +4000,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
>>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
|
||||
>>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)
|
||||
```
|
||||
|
||||
* `low_cpu_mem_usage` algorithm:
|
||||
|
||||
This is an experimental function that loads the model using ~1x model size CPU memory
|
||||
|
||||
Here is how it works:
|
||||
|
||||
1. save which state_dict keys we have
|
||||
2. drop state_dict before the model is created, since the latter takes 1x model size CPU memory
|
||||
3. after the model has been instantiated switch to the meta device all params/buffers that
|
||||
are going to be replaced from the loaded state_dict
|
||||
4. load state_dict 2nd time
|
||||
5. replace the params/buffers from the state_dict
|
||||
|
||||
Currently, it can't handle deepspeed ZeRO stage 3 and ignores loading errors
|
||||
|
||||
"""
|
||||
state_dict = kwargs.pop("state_dict", None)
|
||||
from_tf = kwargs.pop("from_tf", False)
|
||||
from_flax = kwargs.pop("from_flax", False)
|
||||
_ = kwargs.pop("resume_download", None)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
output_loading_info = kwargs.pop("output_loading_info", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
_ = kwargs.pop("trust_remote_code", None)
|
||||
_ = kwargs.pop("mirror", None)
|
||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
_fast_init = kwargs.pop("_fast_init", True)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
max_memory = kwargs.pop("max_memory", None)
|
||||
offload_folder = kwargs.pop("offload_folder", None)
|
||||
@@ -4094,6 +4028,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
gguf_file = kwargs.pop("gguf_file", None)
|
||||
tp_plan = kwargs.pop("tp_plan", None)
|
||||
key_mapping = kwargs.pop("key_mapping", None)
|
||||
# Not used anymore -- remove them from the kwargs
|
||||
_ = kwargs.pop("resume_download", None)
|
||||
_ = kwargs.pop("trust_remote_code", None)
|
||||
_ = kwargs.pop("mirror", None)
|
||||
_ = kwargs.pop("_fast_init", True)
|
||||
_ = kwargs.pop("low_cpu_mem_usage", None)
|
||||
|
||||
if state_dict is not None and (pretrained_model_name_or_path is not None or gguf_file is not None):
|
||||
raise ValueError(
|
||||
@@ -4156,9 +4096,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
world_size = torch.distributed.get_world_size()
|
||||
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
|
||||
|
||||
if is_fsdp_enabled():
|
||||
low_cpu_mem_usage = True
|
||||
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
||||
@@ -4240,20 +4177,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
device_map = {"": device_map}
|
||||
|
||||
if device_map is not None:
|
||||
if low_cpu_mem_usage is None:
|
||||
low_cpu_mem_usage = True
|
||||
elif not low_cpu_mem_usage:
|
||||
raise ValueError("Passing along a `device_map` or a `tp_plan` requires `low_cpu_mem_usage=True`")
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError(
|
||||
"DeepSpeed Zero-3 is not compatible with `low_cpu_mem_usage=True` or with passing a `device_map`."
|
||||
)
|
||||
elif not is_accelerate_available():
|
||||
raise ImportError(
|
||||
f"Using `low_cpu_mem_usage=True`, a `device_map` or a `tp_plan` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
|
||||
)
|
||||
raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
|
||||
|
||||
# handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
|
||||
if load_in_4bit or load_in_8bit:
|
||||
@@ -4355,10 +4280,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
|
||||
else:
|
||||
user_agent["quant"] = hf_quantizer.quantization_config.quant_method
|
||||
# Force-set to `True` for more mem efficiency
|
||||
if low_cpu_mem_usage is None:
|
||||
low_cpu_mem_usage = True
|
||||
logger.warning("`low_cpu_mem_usage` was None, now default to True since model is quantized.")
|
||||
|
||||
if gguf_file is not None and hf_quantizer is not None:
|
||||
raise ValueError(
|
||||
@@ -4438,8 +4359,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
state_dict = load_gguf_checkpoint(checkpoint_files[0], return_tensors=True, model_to_load=dummy_model)[
|
||||
"tensors"
|
||||
]
|
||||
# Force it if is not already the case
|
||||
low_cpu_mem_usage = True
|
||||
|
||||
# Find the correct dtype based on current state
|
||||
config, torch_dtype, dtype_orig = _get_torch_dtype(
|
||||
@@ -4449,7 +4368,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
config.name_or_path = pretrained_model_name_or_path
|
||||
|
||||
# Instantiate model.
|
||||
model_init_context = cls.get_init_context(_fast_init, is_quantized, _is_ds_init_called, low_cpu_mem_usage)
|
||||
model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
|
||||
|
||||
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
|
||||
if not getattr(config, "_attn_implementation_autoset", False):
|
||||
@@ -4480,8 +4399,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
if model._keep_in_fp32_modules is not None and (
|
||||
torch_dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
|
||||
):
|
||||
# Only the path with `low_cpu_mem_usage` will check every param for the correct dtype
|
||||
low_cpu_mem_usage = True
|
||||
# We need to match exact layers, so we add either `.` on each side, or start/end of string
|
||||
keep_in_fp32_regex = re.compile(
|
||||
"|".join([rf"((^|\.){module}($|\.))" for module in model._keep_in_fp32_modules])
|
||||
@@ -4526,7 +4443,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
pretrained_model_name_or_path,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
sharded_metadata=sharded_metadata,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
device_map=device_map,
|
||||
disk_offload_folder=offload_folder,
|
||||
offload_state_dict=offload_state_dict,
|
||||
@@ -4536,7 +4452,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
device_mesh=device_mesh,
|
||||
key_mapping=key_mapping,
|
||||
weights_only=weights_only,
|
||||
_fast_init=_fast_init,
|
||||
)
|
||||
|
||||
# make sure token embedding weights are still tied if needed
|
||||
@@ -4735,7 +4650,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
pretrained_model_name_or_path: Optional[str],
|
||||
ignore_mismatched_sizes: bool = False,
|
||||
sharded_metadata: Optional[Dict] = None,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
device_map: Optional[Dict] = None,
|
||||
disk_offload_folder: Optional[str] = None,
|
||||
offload_state_dict: Optional[bool] = None,
|
||||
@@ -4745,7 +4659,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
|
||||
key_mapping: Optional[Dict[str, str]] = None,
|
||||
weights_only: bool = True,
|
||||
_fast_init: bool = True,
|
||||
):
|
||||
# Useful flags
|
||||
is_quantized = hf_quantizer is not None
|
||||
@@ -4787,20 +4700,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
hf_quantizer,
|
||||
device_map,
|
||||
)
|
||||
# Find all the keys with shape mismatch (if we ignore the mismatch, the weights need to be newly initialized the
|
||||
# same way as missing keys)
|
||||
mismatched_keys, mismatched_shapes = _find_mismatched_keys(
|
||||
model,
|
||||
state_dict,
|
||||
checkpoint_files,
|
||||
ignore_mismatched_sizes,
|
||||
key_renaming_mapping,
|
||||
is_quantized,
|
||||
weights_only,
|
||||
)
|
||||
|
||||
# Move missing keys back to cpu from meta device (because they won't be moved when loading the weights as
|
||||
# they are not in the loaded state dict)
|
||||
if low_cpu_mem_usage:
|
||||
model._move_missing_keys_from_meta_to_cpu(missing_keys, unexpected_keys, dtype, hf_quantizer)
|
||||
# In this case we also need to move everything back
|
||||
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
|
||||
for key, param in model.state_dict().items():
|
||||
if param.device == torch.device("meta"):
|
||||
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
|
||||
# We need to update both the mapping and the list of checkpoint keys to remove the mismatched ones
|
||||
key_renaming_mapping = {k: v for k, v in key_renaming_mapping.items() if v not in mismatched_keys}
|
||||
checkpoint_keys = list(key_renaming_mapping.values())
|
||||
|
||||
# correctly initialize the missing keys if it was skipped before
|
||||
if _fast_init or low_cpu_mem_usage:
|
||||
model._initialize_missing_keys(checkpoint_keys, ignore_mismatched_sizes, is_quantized)
|
||||
# Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when
|
||||
# loading the weights as they are not in the loaded state dict)
|
||||
model._move_missing_keys_from_meta_to_cpu(missing_keys + mismatched_keys, unexpected_keys, dtype, hf_quantizer)
|
||||
|
||||
# correctly initialize the missing (and potentially mismatched) keys
|
||||
model._initialize_missing_keys(checkpoint_keys, ignore_mismatched_sizes, is_quantized)
|
||||
|
||||
# Set some modules to fp32 if needed
|
||||
if keep_in_fp32_regex is not None:
|
||||
@@ -4907,7 +4828,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
|
||||
|
||||
error_msgs = []
|
||||
mismatched_keys = []
|
||||
# Iterate on all the shards to load the weights
|
||||
for shard_file in checkpoint_files:
|
||||
# Skip the load for shards that only contain disk-offloaded weights
|
||||
@@ -4915,16 +4835,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
continue
|
||||
|
||||
map_location = "cpu"
|
||||
if low_cpu_mem_usage:
|
||||
if shard_file.endswith(".safetensors"):
|
||||
map_location = "meta"
|
||||
elif (
|
||||
device_map is not None
|
||||
and hf_quantizer is not None
|
||||
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
|
||||
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
|
||||
):
|
||||
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
|
||||
if shard_file.endswith(".safetensors"):
|
||||
map_location = "meta"
|
||||
elif (
|
||||
device_map is not None
|
||||
and hf_quantizer is not None
|
||||
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
|
||||
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
|
||||
):
|
||||
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
|
||||
|
||||
# If shard_file is "", we use the existing state_dict instead of loading it
|
||||
if shard_file != "":
|
||||
@@ -4935,41 +4854,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# Fix the key names
|
||||
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
|
||||
|
||||
# Mismatched keys contain tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
|
||||
# matching the weights in the model.
|
||||
mismatched_keys += _find_mismatched_keys(
|
||||
model_to_load,
|
||||
state_dict,
|
||||
ignore_mismatched_sizes,
|
||||
prefix if loading_base_model_from_task_state_dict else "",
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
# Skip it with fsdp on ranks other than 0
|
||||
if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
|
||||
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
|
||||
model_to_load,
|
||||
state_dict,
|
||||
shard_file,
|
||||
expected_keys,
|
||||
reverse_key_renaming_mapping,
|
||||
device_map=device_map,
|
||||
disk_offload_folder=disk_offload_folder,
|
||||
disk_offload_index=disk_offload_index,
|
||||
cpu_offload_folder=cpu_offload_folder,
|
||||
cpu_offload_index=cpu_offload_index,
|
||||
hf_quantizer=hf_quantizer,
|
||||
is_safetensors=is_offloaded_safetensors,
|
||||
keep_in_fp32_regex=keep_in_fp32_regex,
|
||||
unexpected_keys=unexpected_keys,
|
||||
device_mesh=device_mesh,
|
||||
)
|
||||
else:
|
||||
assign_params = check_support_param_buffer_assignment(model_to_load, state_dict)
|
||||
if is_deepspeed_zero3_enabled():
|
||||
error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict, assign_params)
|
||||
else:
|
||||
model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params)
|
||||
if is_deepspeed_zero3_enabled():
|
||||
error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
|
||||
# Skip it with fsdp on ranks other than 0
|
||||
elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
|
||||
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
|
||||
model_to_load,
|
||||
state_dict,
|
||||
shard_file,
|
||||
expected_keys,
|
||||
reverse_key_renaming_mapping,
|
||||
device_map=device_map,
|
||||
disk_offload_folder=disk_offload_folder,
|
||||
disk_offload_index=disk_offload_index,
|
||||
cpu_offload_folder=cpu_offload_folder,
|
||||
cpu_offload_index=cpu_offload_index,
|
||||
hf_quantizer=hf_quantizer,
|
||||
is_safetensors=is_offloaded_safetensors,
|
||||
keep_in_fp32_regex=keep_in_fp32_regex,
|
||||
unexpected_keys=unexpected_keys,
|
||||
device_mesh=device_mesh,
|
||||
)
|
||||
|
||||
# force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
|
||||
del state_dict
|
||||
@@ -5068,7 +4973,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
mismatched_warning = "\n".join(
|
||||
[
|
||||
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
||||
for key, shape1, shape2 in mismatched_keys
|
||||
for key, (shape1, shape2) in zip(mismatched_keys, mismatched_shapes)
|
||||
]
|
||||
)
|
||||
logger.warning(
|
||||
@@ -5323,19 +5228,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
"""
|
||||
is_quantized = hf_quantizer is not None
|
||||
|
||||
# In this case we need to move everything back
|
||||
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
|
||||
# We only do it for the parameters, as the buffers are not initialized on the meta device by default
|
||||
for key, param in self.named_parameters():
|
||||
value = torch.empty_like(param, dtype=dtype, device="cpu")
|
||||
_load_parameter_into_model(self, key, value)
|
||||
return
|
||||
|
||||
model_state_dict = self.state_dict()
|
||||
for key in missing_keys:
|
||||
param = model_state_dict[key]
|
||||
# Buffers are not initialized on the meta device, so we still need this check to avoid overwriting them
|
||||
if param.device == torch.device("meta"):
|
||||
# upcast in fp32 if any
|
||||
target_dtype = dtype
|
||||
value = torch.empty(*param.size(), dtype=target_dtype)
|
||||
value = torch.empty_like(param, dtype=dtype, device="cpu")
|
||||
if (
|
||||
not is_quantized
|
||||
or (getattr(hf_quantizer, "requires_parameters_quantization", False))
|
||||
or not hf_quantizer.check_quantized_param(self, param_value=value, param_name=key, state_dict={})
|
||||
):
|
||||
set_module_tensor_to_device(self, key, "cpu", value)
|
||||
_load_parameter_into_model(self, key, value)
|
||||
else:
|
||||
hf_quantizer.create_quantized_param(self, value, key, "cpu", model_state_dict, unexpected_keys)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user