Remove low_cpu_mem_usage and _fast_init (#36963)

* Remove low_cpu_mem_usage and _fast_init

* Update deepspeed.py

* Update modeling_utils.py

* remove the first 2 tests everywhere

* Update test_modeling_common.py

* remove what was remaining about fast_init

* fix logic and simplify

* mismatched keys logic update

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* fix 2 models init_weights

* extend to others

* remove grad

* Update modeling_fsmt.py

* init weights in tests

* style

* Update test_modeling_fsmt.py

* more old models

* fix more init_weights

* copies

* fix

* style

* Update modeling_lxmert.py

* fix inits

* more and more

* more

* should finalize

* style

* Update modeling_dinov2_with_registers.py

* fix

* Update modeling_encoder_decoder.py

* fix

* style

* Update modeling_lxmert.py

* post rebase cleanup

* Update modeling_informer.py

* back to start for device

* fix

* add test to detect all failing cases correctly

* Update test_modeling_common.py

* fix

* fix

* sam

* style

* Update modeling_maskformer_swin.py

* CIs

* CIs

* remove test - will add it on separate PR

* fix

* fix

* Update modeling_sam.py

* CIs

* CIs

* CIs

* convnext

* suggestions

* CIs

* fix copies after merge

---------

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
Cyril Vallez
2025-03-31 17:18:43 +02:00
committed by GitHub
parent 8805600406
commit f304318f5f
128 changed files with 464 additions and 1165 deletions

View File

@@ -73,7 +73,6 @@ from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
from .safetensors_conversion import auto_conversion
from .utils import (
ACCELERATE_MIN_VERSION,
ADAPTER_SAFE_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME,
CONFIG_NAME,
@@ -137,7 +136,6 @@ if is_accelerate_available():
load_offloaded_weights,
offload_weight,
save_offload_index,
set_module_tensor_to_device,
)
accelerate_version = version.parse(importlib.metadata.version("accelerate"))
@@ -208,32 +206,29 @@ TORCH_INIT_FUNCTIONS = {
@contextmanager
def no_init_weights(_enable=True):
def no_init_weights():
"""
Context manager to globally disable weight initialization to speed up loading large models.
TODO(Patrick): Delete safety argument `_enable=True` at next major version. .
"""
global _init_weights
old_init_weights = _init_weights
if _enable:
_init_weights = False
_init_weights = False
def _skip_init(*args, **kwargs):
pass
def _skip_init(*args, **kwargs):
pass
# Save the original initialization functions
for name, init_func in TORCH_INIT_FUNCTIONS.items():
setattr(torch.nn.init, name, _skip_init)
# # Save the original initialization functions
for name, init_func in TORCH_INIT_FUNCTIONS.items():
setattr(torch.nn.init, name, _skip_init)
try:
yield
finally:
_init_weights = old_init_weights
if _enable:
# # Restore the original initialization functions
for name, init_func in TORCH_INIT_FUNCTIONS.items():
setattr(torch.nn.init, name, init_func)
# Restore the original initialization functions
for name, init_func in TORCH_INIT_FUNCTIONS.items():
setattr(torch.nn.init, name, init_func)
@contextmanager
@@ -404,37 +399,6 @@ def dtype_byte_size(dtype):
return bit_size // 8
def check_support_param_buffer_assignment(model_to_load, state_dict):
"""
Checks if `model_to_load` supports param buffer assignment (such
as when loading in empty weights) by first checking
if the model explicitly disables it, then by ensuring that the state dict keys
are a subset of the model's parameters.
Note: We fully disable this if we are using `deepspeed`
"""
if len(state_dict) == 0:
return False
if is_deepspeed_zero3_enabled():
return False
# Some models explicitly do not support param buffer assignment
if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
logger.debug(
f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
)
return False
# If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
first_key = next(iter(model_to_load.state_dict().keys()))
if first_key in state_dict:
return state_dict[first_key].dtype == model_to_load.state_dict()[first_key].dtype
# For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`)
return False
def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
"""
This is the same as
@@ -750,6 +714,13 @@ def _infer_parameter_dtype(
return old_param is not None and old_param.is_contiguous(), casting_dtype
def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor: torch.Tensor):
"""Cast a single parameter `param_name` into the `model`, with value `tensor`."""
module, param_type = get_module_from_name(model, param_name)
# This will check potential shape mismatch if skipped before
module.load_state_dict({param_type: tensor}, strict=False, assign=True)
@torch.no_grad()
def _load_state_dict_into_meta_model(
model: "PreTrainedModel",
@@ -857,17 +828,12 @@ def _load_state_dict_into_meta_model(
):
if is_fsdp_enabled():
param_device = "cpu" if is_local_dist_rank_0() else "meta"
module, param_type = get_module_from_name(model, param_name)
# avoid tied weights
if param.data_ptr() in data_ptrs:
param = param.clone()
module.load_state_dict(
{param_type: param.to(param_device)},
strict=False,
assign=True,
)
_load_parameter_into_model(model, param_name, param.to(param_device))
# Add `data_ptr` of `model.state_dict()[param_name]` to avoid tied weights
data_ptrs.add(model.state_dict()[param_name].data_ptr())
@@ -1397,18 +1363,7 @@ def _find_missing_and_unexpected_keys(
if has_inv_freq_buffers:
unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]
if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
id_tensor = id_tensor_storage(tensor)
ptrs[id_tensor].append(name)
# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
else:
# id function doesn't work for meta tensor so we need this function
tied_params = find_tied_parameters(model)
tied_params = find_tied_parameters(model)
for group in tied_params:
missing_in_group = [k for k in missing_keys if k in group]
if len(missing_in_group) > 0 and len(missing_in_group) < len(group):
@@ -1430,29 +1385,59 @@ def _find_missing_and_unexpected_keys(
def _find_mismatched_keys(
model_to_load: "PreTrainedModel",
state_dict: Dict,
model: "PreTrainedModel",
state_dict: Optional[Dict],
checkpoint_files: Optional[List[str]],
ignore_mismatched_sizes: bool,
prefix: str,
) -> List:
"""Find mismatch of shapes between the model parameters and the loaded state dict, and optionally remove the
problematic keys from `state_dict` if `ignore_mismatched_sizes` is `True`."""
keys_to_rename_mapping: Dict[str, str],
is_quantized: bool,
weights_only: bool,
) -> Tuple[List[str], List[Tuple[int, int]]]:
"""
Find potential shape mismatch between the different state dicts and the model parameters, but only if `ignore_mismatched_sizes`
is True. Otherwise, return immediately and any shape mismatch that may exist will be raised later on. This avoids checking
every parameter in advance, as shape mismatch are extremely rare in practice. If we want to ignore them however, we do
need to check in advance as we need to know which parameters we need to move back from meta to cpu, and initialize
correctly. Indeed, as our model initialization takes place at the module level, and not the weight level, in the
case of a sharded checkpoint we cannot correctly initialize the weights according to `model._init_weights()` if we perform
this check on each state dict at loading time (after the first loaded checkpoint, there are no way to initialize only the
mismatched weights if any, without overwriting the previously loaded weights as well because all the module will be
initialized, not only the weights that are mismatched).
"""
# An error will be raised later on anyway if there is a mismatch - this avoids running the rest of this function
# if there are no mismatch (which is almost always the case)
if not ignore_mismatched_sizes:
return [], []
if state_dict is not None:
checkpoint_files = [""]
model_state_dict = model.state_dict()
mismatched_keys = []
if ignore_mismatched_sizes:
model_state_dict = model_to_load.state_dict()
state_dict_keys = list(state_dict.keys())
for key in state_dict_keys:
if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape:
if state_dict[key].shape[-1] == 1 and state_dict[key].numel() * 2 == model_state_dict[key].numel():
# This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
# Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights.
pass
else:
# Add prefix if we removed it before, to add the correct state dict key to the warnings
key_with_prefix = prefix + key
mismatched_keys.append((key_with_prefix, state_dict[key].shape, model_state_dict[key].shape))
del state_dict[key]
return mismatched_keys
mismatched_shapes = []
for shard_file in checkpoint_files:
# If shard_file is "", we use the existing state_dict instead of loading it
if shard_file != "":
state_dict = load_state_dict(
shard_file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only
)
# Fix the key names
new_state_dict = {keys_to_rename_mapping[k]: v for k, v in state_dict.items() if k in keys_to_rename_mapping}
for key in new_state_dict.keys():
if key in model_state_dict and new_state_dict[key].shape != model_state_dict[key].shape:
# This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
# Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights.
if not (
new_state_dict[key].shape[-1] == 1
and new_state_dict[key].numel() * 2 == model_state_dict[key].numel()
):
mismatched_keys.append(key)
mismatched_shapes.append((new_state_dict[key].shape, model_state_dict[key].shape))
return mismatched_keys, mismatched_shapes
class PipelineParallel(Enum):
@@ -3773,13 +3758,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
@classmethod
def get_init_context(
cls: Type[SpecificPreTrainedModelType],
_fast_init=True,
is_quantized=None,
_is_ds_init_called=None,
low_cpu_mem_usage=True,
):
init_contexts = [no_init_weights(_enable=_fast_init)]
if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
import deepspeed
@@ -3787,13 +3768,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
init_contexts = [
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
set_zero3_state(),
] + init_contexts
elif low_cpu_mem_usage:
if not is_accelerate_available():
raise ImportError(
f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
)
init_contexts.append(init_empty_weights())
no_init_weights(),
]
else:
init_contexts = [no_init_weights(), init_empty_weights()]
if is_deepspeed_zero3_enabled() and is_quantized:
init_contexts.append(set_quantized_state())
@@ -3829,10 +3807,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
weights are discarded.
If model weights are the same precision as the base model (and is a supported model), weights will be lazily loaded
in using the `meta` device and brought into memory once an input is passed through that layer regardless of
`low_cpu_mem_usage`.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
@@ -3910,31 +3884,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
</Tip>
_fast_init(`bool`, *optional*, defaults to `True`):
Whether or not to disable fast initialization.
<Tip warning={true}>
One should only disable *_fast_init* to ensure backwards compatibility with `transformers.__version__ <
4.6.0` for seeded model initialization. This argument will be removed at the next major version. See
[pull request 11471](https://github.com/huggingface/transformers/pull/11471) for more information.
</Tip>
attn_implementation (`str`, *optional*):
The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
> Parameters for big model inference
low_cpu_mem_usage(`bool`, *optional*):
Tries not to use more than 1x model size in CPU memory (including peak memory) while loading the model.
Generally should be combined with a `device_map` (such as `"auto"`) for best results.
This is an experimental feature and a subject to change at any moment.
</Tip>
If the model weights are in the same precision as the model loaded in, `low_cpu_mem_usage` (without
`device_map`) is redundant and will not provide any benefit in regards to CPU memory usage. However,
this should still be enabled if you are passing in a `device_map`.
</Tip>
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under a specific `dtype`. The different options
are:
@@ -4045,37 +4000,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
>>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
>>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)
```
* `low_cpu_mem_usage` algorithm:
This is an experimental function that loads the model using ~1x model size CPU memory
Here is how it works:
1. save which state_dict keys we have
2. drop state_dict before the model is created, since the latter takes 1x model size CPU memory
3. after the model has been instantiated switch to the meta device all params/buffers that
are going to be replaced from the loaded state_dict
4. load state_dict 2nd time
5. replace the params/buffers from the state_dict
Currently, it can't handle deepspeed ZeRO stage 3 and ignores loading errors
"""
state_dict = kwargs.pop("state_dict", None)
from_tf = kwargs.pop("from_tf", False)
from_flax = kwargs.pop("from_flax", False)
_ = kwargs.pop("resume_download", None)
proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
use_auth_token = kwargs.pop("use_auth_token", None)
_ = kwargs.pop("trust_remote_code", None)
_ = kwargs.pop("mirror", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
_fast_init = kwargs.pop("_fast_init", True)
torch_dtype = kwargs.pop("torch_dtype", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None)
device_map = kwargs.pop("device_map", None)
max_memory = kwargs.pop("max_memory", None)
offload_folder = kwargs.pop("offload_folder", None)
@@ -4094,6 +4028,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
gguf_file = kwargs.pop("gguf_file", None)
tp_plan = kwargs.pop("tp_plan", None)
key_mapping = kwargs.pop("key_mapping", None)
# Not used anymore -- remove them from the kwargs
_ = kwargs.pop("resume_download", None)
_ = kwargs.pop("trust_remote_code", None)
_ = kwargs.pop("mirror", None)
_ = kwargs.pop("_fast_init", True)
_ = kwargs.pop("low_cpu_mem_usage", None)
if state_dict is not None and (pretrained_model_name_or_path is not None or gguf_file is not None):
raise ValueError(
@@ -4156,9 +4096,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
world_size = torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
if is_fsdp_enabled():
low_cpu_mem_usage = True
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
@@ -4240,20 +4177,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
device_map = {"": device_map}
if device_map is not None:
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
elif not low_cpu_mem_usage:
raise ValueError("Passing along a `device_map` or a `tp_plan` requires `low_cpu_mem_usage=True`")
if low_cpu_mem_usage:
if is_deepspeed_zero3_enabled():
raise ValueError(
"DeepSpeed Zero-3 is not compatible with `low_cpu_mem_usage=True` or with passing a `device_map`."
)
elif not is_accelerate_available():
raise ImportError(
f"Using `low_cpu_mem_usage=True`, a `device_map` or a `tp_plan` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
)
raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
# handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
if load_in_4bit or load_in_8bit:
@@ -4355,10 +4280,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
else:
user_agent["quant"] = hf_quantizer.quantization_config.quant_method
# Force-set to `True` for more mem efficiency
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
logger.warning("`low_cpu_mem_usage` was None, now default to True since model is quantized.")
if gguf_file is not None and hf_quantizer is not None:
raise ValueError(
@@ -4438,8 +4359,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
state_dict = load_gguf_checkpoint(checkpoint_files[0], return_tensors=True, model_to_load=dummy_model)[
"tensors"
]
# Force it if is not already the case
low_cpu_mem_usage = True
# Find the correct dtype based on current state
config, torch_dtype, dtype_orig = _get_torch_dtype(
@@ -4449,7 +4368,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
config.name_or_path = pretrained_model_name_or_path
# Instantiate model.
model_init_context = cls.get_init_context(_fast_init, is_quantized, _is_ds_init_called, low_cpu_mem_usage)
model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
if not getattr(config, "_attn_implementation_autoset", False):
@@ -4480,8 +4399,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if model._keep_in_fp32_modules is not None and (
torch_dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
):
# Only the path with `low_cpu_mem_usage` will check every param for the correct dtype
low_cpu_mem_usage = True
# We need to match exact layers, so we add either `.` on each side, or start/end of string
keep_in_fp32_regex = re.compile(
"|".join([rf"((^|\.){module}($|\.))" for module in model._keep_in_fp32_modules])
@@ -4526,7 +4443,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
sharded_metadata=sharded_metadata,
low_cpu_mem_usage=low_cpu_mem_usage,
device_map=device_map,
disk_offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
@@ -4536,7 +4452,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
device_mesh=device_mesh,
key_mapping=key_mapping,
weights_only=weights_only,
_fast_init=_fast_init,
)
# make sure token embedding weights are still tied if needed
@@ -4735,7 +4650,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
pretrained_model_name_or_path: Optional[str],
ignore_mismatched_sizes: bool = False,
sharded_metadata: Optional[Dict] = None,
low_cpu_mem_usage: bool = False,
device_map: Optional[Dict] = None,
disk_offload_folder: Optional[str] = None,
offload_state_dict: Optional[bool] = None,
@@ -4745,7 +4659,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
key_mapping: Optional[Dict[str, str]] = None,
weights_only: bool = True,
_fast_init: bool = True,
):
# Useful flags
is_quantized = hf_quantizer is not None
@@ -4787,20 +4700,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
hf_quantizer,
device_map,
)
# Find all the keys with shape mismatch (if we ignore the mismatch, the weights need to be newly initialized the
# same way as missing keys)
mismatched_keys, mismatched_shapes = _find_mismatched_keys(
model,
state_dict,
checkpoint_files,
ignore_mismatched_sizes,
key_renaming_mapping,
is_quantized,
weights_only,
)
# Move missing keys back to cpu from meta device (because they won't be moved when loading the weights as
# they are not in the loaded state dict)
if low_cpu_mem_usage:
model._move_missing_keys_from_meta_to_cpu(missing_keys, unexpected_keys, dtype, hf_quantizer)
# In this case we also need to move everything back
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
for key, param in model.state_dict().items():
if param.device == torch.device("meta"):
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
# We need to update both the mapping and the list of checkpoint keys to remove the mismatched ones
key_renaming_mapping = {k: v for k, v in key_renaming_mapping.items() if v not in mismatched_keys}
checkpoint_keys = list(key_renaming_mapping.values())
# correctly initialize the missing keys if it was skipped before
if _fast_init or low_cpu_mem_usage:
model._initialize_missing_keys(checkpoint_keys, ignore_mismatched_sizes, is_quantized)
# Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when
# loading the weights as they are not in the loaded state dict)
model._move_missing_keys_from_meta_to_cpu(missing_keys + mismatched_keys, unexpected_keys, dtype, hf_quantizer)
# correctly initialize the missing (and potentially mismatched) keys
model._initialize_missing_keys(checkpoint_keys, ignore_mismatched_sizes, is_quantized)
# Set some modules to fp32 if needed
if keep_in_fp32_regex is not None:
@@ -4907,7 +4828,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
error_msgs = []
mismatched_keys = []
# Iterate on all the shards to load the weights
for shard_file in checkpoint_files:
# Skip the load for shards that only contain disk-offloaded weights
@@ -4915,16 +4835,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
continue
map_location = "cpu"
if low_cpu_mem_usage:
if shard_file.endswith(".safetensors"):
map_location = "meta"
elif (
device_map is not None
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
if shard_file.endswith(".safetensors"):
map_location = "meta"
elif (
device_map is not None
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
# If shard_file is "", we use the existing state_dict instead of loading it
if shard_file != "":
@@ -4935,41 +4854,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Fix the key names
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys += _find_mismatched_keys(
model_to_load,
state_dict,
ignore_mismatched_sizes,
prefix if loading_base_model_from_task_state_dict else "",
)
if low_cpu_mem_usage:
# Skip it with fsdp on ranks other than 0
if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
shard_file,
expected_keys,
reverse_key_renaming_mapping,
device_map=device_map,
disk_offload_folder=disk_offload_folder,
disk_offload_index=disk_offload_index,
cpu_offload_folder=cpu_offload_folder,
cpu_offload_index=cpu_offload_index,
hf_quantizer=hf_quantizer,
is_safetensors=is_offloaded_safetensors,
keep_in_fp32_regex=keep_in_fp32_regex,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
)
else:
assign_params = check_support_param_buffer_assignment(model_to_load, state_dict)
if is_deepspeed_zero3_enabled():
error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict, assign_params)
else:
model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params)
if is_deepspeed_zero3_enabled():
error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
# Skip it with fsdp on ranks other than 0
elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
shard_file,
expected_keys,
reverse_key_renaming_mapping,
device_map=device_map,
disk_offload_folder=disk_offload_folder,
disk_offload_index=disk_offload_index,
cpu_offload_folder=cpu_offload_folder,
cpu_offload_index=cpu_offload_index,
hf_quantizer=hf_quantizer,
is_safetensors=is_offloaded_safetensors,
keep_in_fp32_regex=keep_in_fp32_regex,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
)
# force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
del state_dict
@@ -5068,7 +4973,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
mismatched_warning = "\n".join(
[
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
for key, shape1, shape2 in mismatched_keys
for key, (shape1, shape2) in zip(mismatched_keys, mismatched_shapes)
]
)
logger.warning(
@@ -5323,19 +5228,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"""
is_quantized = hf_quantizer is not None
# In this case we need to move everything back
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
# We only do it for the parameters, as the buffers are not initialized on the meta device by default
for key, param in self.named_parameters():
value = torch.empty_like(param, dtype=dtype, device="cpu")
_load_parameter_into_model(self, key, value)
return
model_state_dict = self.state_dict()
for key in missing_keys:
param = model_state_dict[key]
# Buffers are not initialized on the meta device, so we still need this check to avoid overwriting them
if param.device == torch.device("meta"):
# upcast in fp32 if any
target_dtype = dtype
value = torch.empty(*param.size(), dtype=target_dtype)
value = torch.empty_like(param, dtype=dtype, device="cpu")
if (
not is_quantized
or (getattr(hf_quantizer, "requires_parameters_quantization", False))
or not hf_quantizer.check_quantized_param(self, param_value=value, param_name=key, state_dict={})
):
set_module_tensor_to_device(self, key, "cpu", value)
_load_parameter_into_model(self, key, value)
else:
hf_quantizer.create_quantized_param(self, value, key, "cpu", model_state_dict, unexpected_keys)