Update form pretrained to make TP a first class citizen (#36335)

* clean code

* oups

* fix merge

* yups

* fix if

* now you can play

* fix shape issue

* try non blocking

* fix

* updates

* up

* updates

* fix most of thetests

* update

* update

* small updates

* up

* fix the remaining bug?

* update

* rename when you read from the file

* buffer issues

* current status

* cleanup

* properly allocate dumb memory

* update a small bug

* fix colwise rep issue

* fix keep in float 32 that was keeping everything in float 32

* typo

* more fixes with keep_in_fp32_modules as we use to serach on it

* fix ROPE dtype for TP

* remove what's breaking the tests

* updates

* update and fixes

* small cleanup after merging

* allocate 2x to be safe

* style, auto

* update

* yup nit

* fix

* remove slow as fuck torch api :(

* work

* fixup

* update

* brting the fix back

* fix and update

* fixes

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* updates because some suggestions were wrong 👀

* update?

* fuck this bloated function

* typo

* fix the dumb prefix thing once and forall

* fixes here and there

* updates

* remove prints

* fix strict cases

* styel

* properly fix keys on load!

* update

* fix base model prefix issue

* style

* update

* fix all?

* remoce 1 print

* fix the final etsts

* fixup

* last nits

* fix the detach issue which cause a 2x slowdown

* fixup

* small fixes

* ultra nit

* fix

* fix

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Arthur
2025-02-26 20:12:38 +01:00
committed by GitHub
parent 981c276a02
commit 1603018e7a
36 changed files with 442 additions and 340 deletions

View File

@@ -37,9 +37,11 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVa
from zipfile import is_zipfile
import torch
import torch.distributed.tensor
from huggingface_hub import split_torch_state_dict_into_shards
from packaging import version
from torch import Tensor, nn
from torch.distributed.tensor import DTensor, Shard
from torch.distributions import constraints
from torch.nn import CrossEntropyLoss, Identity
from torch.utils.checkpoint import checkpoint
@@ -56,6 +58,7 @@ from .loss.loss_utils import LOSS_MAPPING
from .pytorch_utils import ( # noqa: F401
Conv1D,
apply_chunking_to_forward,
distribute_module,
find_pruneable_heads_and_indices,
id_tensor_storage,
prune_conv1d_layer,
@@ -404,9 +407,6 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefi
Note: We fully disable this if we are using `deepspeed`
"""
if model_to_load.device.type == "meta":
return False
if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
return False
@@ -514,25 +514,50 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
str_to_torch_dtype = {
"BOOL": torch.bool,
"U8": torch.uint8,
"I8": torch.int8,
"I16": torch.int16,
"U16": torch.uint16,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I32": torch.int32,
"U32": torch.uint32,
"F32": torch.float32,
"F64": torch.float64,
"I64": torch.int64,
"U64": torch.uint64,
}
def load_state_dict(
checkpoint_file: Union[str, os.PathLike],
is_quantized: bool = False,
map_location: Optional[Union[str, torch.device]] = None,
map_location: Optional[Union[str, torch.device]] = "meta",
weights_only: bool = True,
):
"""
Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
Reads a `safetensor` or a `.bin` checkpoint file into `meta` if requested.
"""
if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
# Check format of the archive
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
raise OSError(
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
"you save your model with the `save_pretrained` method."
)
return safe_load_file(checkpoint_file)
if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
raise OSError(
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
"you save your model with the `save_pretrained` method."
)
state_dict = {}
for k in f.keys():
dtype = str_to_torch_dtype[f.get_slice(k).get_dtype()]
if map_location == "meta":
state_dict[k] = torch.empty(size=f.get_slice(k).get_shape(), dtype=dtype, device="meta")
else:
state_dict[k] = f.get_tensor(k)
return state_dict
try:
if map_location is None:
if (
@@ -677,54 +702,6 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
return shared_tensors, identical
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
error_msgs = []
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
if is_deepspeed_zero3_enabled():
import deepspeed
# In sharded models, each shard has only part of the full state_dict, so only gather
# parameters that are in the current state_dict.
named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
if len(params_to_gather) > 0:
# because zero3 puts placeholders in model params, this context
# manager gathers (unpartitions) the params of the current layer, then loads from
# the state dict and then re-partitions them again
with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
if torch.distributed.get_rank() == 0:
module._load_from_state_dict(*args)
else:
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
# Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
# it's safe to delete it.
del state_dict
return error_msgs
def find_submodule_and_param_name(model, long_key, start_prefix):
"""
A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied it'll be removed
@@ -774,9 +751,10 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
setattr(submodule, param_name, new_val)
@torch.no_grad()
def _load_state_dict_into_meta_model(
model,
state_dict,
model: torch.nn.Module,
state_dict: Dict[str, torch.Tensor],
start_prefix,
expected_keys,
device_map=None,
@@ -791,6 +769,7 @@ def _load_state_dict_into_meta_model(
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys
device_mesh=None,
shard_file=None,
):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -803,167 +782,157 @@ def _load_state_dict_into_meta_model(
It also initialize tensor parallelism for each module if needed.
"""
tensor_device = None
if device_map is not None and device_map.get("", None) is not None:
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
# XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
# - deepspeed zero 3 support
# - need to copy metadata if any - see _load_state_dict_into_model
# - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
with safe_open(shard_file, framework="pt", device=tensor_device) as file_pointer:
error_msgs = []
error_msgs = []
is_quantized = hf_quantizer is not None
is_quantized = hf_quantizer is not None
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
# we need this later to initialize tensor parallelism
if device_mesh is not None:
full_tp_plan = model.config.base_model_tp_plan
for submodule in model.modules():
full_tp_plan.update(getattr(submodule, "_tp_plan", {}))
for param_name, param in state_dict.items():
if param_name not in expected_keys:
continue
if param_name.startswith(start_prefix):
param_name = param_name[len(start_prefix) :]
module_name = param_name
set_module_kwargs = {}
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn:
if (
keep_in_fp32_modules is not None
and any(
module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
)
and dtype == torch.float16
):
param = param.to(torch.float32)
# For backward compatibility with older versions of `accelerate`
# TODO: @sgugger replace this check with version check at the next `accelerate` release
if "dtype" in list(inspect.signature(set_module_tensor_to_device).parameters):
set_module_kwargs["dtype"] = torch.float32
else:
param = param.to(dtype)
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
old_param = model
splits = param_name.split(".")
for split in splits:
# We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys.
old_param = getattr(old_param, split, None)
if old_param is None:
break
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None
if old_param is not None:
if dtype is None:
param = param.to(old_param.dtype)
if old_param.is_contiguous():
param = param.contiguous()
set_module_kwargs["value"] = param
if device_map is None:
param_device = "cpu"
else:
# find next higher level module that is defined in device_map:
# bert.lm_head.weight -> bert.lm_head -> bert -> ''
while len(module_name) > 0 and module_name not in device_map:
module_name = ".".join(module_name.split(".")[:-1])
if module_name == "" and "" not in device_map:
# TODO: group all errors and raise at the end.
raise ValueError(f"{param_name} doesn't have any device set.")
param_device = device_map[module_name]
if param_device == "disk":
if not is_safetensors:
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
elif (
not is_quantized
or (not hf_quantizer.requires_parameters_quantization)
or (
not hf_quantizer.check_quantized_param(
model, param, param_name, state_dict, param_device=param_device, device_map=device_map
)
)
):
if is_fsdp_enabled():
param_device = "cpu" if is_local_dist_rank_0() else "meta"
# For backward compatibility with older versions of `accelerate` and for non-quantized params
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
else:
hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
# and then cast it to CPU to avoid excessive memory usage on each GPU
# in comparison to the sharded model across GPUs.
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
module, tensor_name = get_module_from_name(model, param_name)
value = getattr(module, tensor_name)
param_to = "cpu"
if is_fsdp_enabled() and not is_local_dist_rank_0():
param_to = "meta"
val_kwargs = {}
if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
val_kwargs["requires_grad"] = False
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
setattr(module, tensor_name, value)
# TODO: consider removing used param_parts from state_dict before return
# In this case, let's parallelize the modules!
# we need this later to initialize tensor parallelism
if device_mesh is not None:
# Immediate parent
split_parent_module_name = param_name.split(".")[:-1]
parent_module_name = ".".join(split_parent_module_name)
parent_module = model
for name in split_parent_module_name:
parent_module = getattr(parent_module, name)
full_tp_plan = model.config.base_model_tp_plan
for submodule in model.modules():
full_tp_plan.update(getattr(submodule, "_tp_plan", {}))
# Check if we are part of the tp_plan
current_module_plan = None
for param, plan in full_tp_plan.items():
# "*" are a placeholder for layer indices, so we replace them by "[0-9]+" in the regex pattern
pattern = param.replace("*", "[0-9]+")
if re.search(pattern, parent_module_name):
current_module_plan = plan
break
for serialized_param_name, empty_param in state_dict.items():
# param_name is the raw, serialized name
# new_param_name is the model's equivalent
module_name, _ = model.rename_key(serialized_param_name)
if module_name not in expected_keys:
continue
layer, param_type = module_name.rsplit(".", 1)
# We can only apply the tp_plan after all parameters of the current module have been correctly initialized (e.g.
# if we have bias, we need both `weights` and `bias` of a nn.Linear to be initialized)
process_device = list(device_map.values())[0]
all_module_parameters_initialized = all(
m.device == process_device for m in parent_module.parameters(recurse=False)
) and all(m.device == process_device for m in parent_module.buffers(recurse=False))
if current_module_plan is not None and all_module_parameters_initialized:
torch.distributed.tensor.parallel.parallelize_module(
parent_module,
device_mesh=device_mesh,
parallelize_plan=translate_to_torch_parallel_style(current_module_plan),
)
# param name needs to stay untouched as it's in the file
param = file_pointer.get_slice(serialized_param_name)
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
param_casting_dtype = None
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
if dtype is not None and empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
if (
keep_in_fp32_modules is not None
and keep_in_fp32_modules.search(module_name)
and dtype == torch.float16
):
param_casting_dtype = torch.float32
else:
param_casting_dtype = dtype
if device_mesh is not None: # In this case, the param is already on the correct device!
try:
module_to_tp: torch.nn.Module = model.get_submodule(layer)
except Exception:
raise ValueError(
"The config tp plan is wrong because the layer is not a liner layer, nor an embedding"
)
current_module_plan = None
full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+")
if plan := re.search(full_tp_plan_, module_name):
match = re.sub("[0-9]+", "*", plan[0])
current_module_plan = full_tp_plan[match]
if current_module_plan is not None:
tp_layer = translate_to_torch_parallel_style(current_module_plan)
rank = tensor_device
row, col = empty_param.shape
if "rowwise" == current_module_plan:
param = param[:, rank * (col // device_mesh.size()) : (rank + 1) * (col // device_mesh.size())]
shard = Shard(1)
tp_layer.desired_input_layouts = (Shard(-1),)
elif "colwise" == current_module_plan:
param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
shard = Shard(0)
else:
param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
shard = Shard(0)
if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype:
param = param.to(param_casting_dtype)
local_parameter = DTensor.from_local(
param,
device_mesh=device_mesh,
placements=[shard] * device_mesh.ndim,
)
if isinstance(module_to_tp.weight, nn.Parameter):
local_parameter = torch.nn.Parameter(local_parameter)
module_to_tp.weight = local_parameter
input_fn = partial(
tp_layer._prepare_input_fn, tp_layer.input_layouts, tp_layer.desired_input_layouts
)
output_fn = partial(
tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output
)
distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn)
else:
module_to_tp.load_state_dict({param_type: param[:]}, False, True)
else:
if device_map is None:
param_device = "cpu"
else:
module_name = module_name.rsplit(".", 1)[0]
device_map_regex = "|".join(device_map.keys())
module_layer = re.search(device_map_regex, module_name)
if module_name == "" or device_map_regex is None:
raise ValueError(
f"`device_map` is used, but {module_name} doesn't have any device set. {device_map}"
)
else:
param_device = device_map[module_layer.group()]
if param_device == "disk" and not is_safetensors:
offload_index = offload_weight(param[:], module_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param[:], module_name, state_dict_folder, state_dict_index)
elif (
not is_quantized
or (not hf_quantizer.requires_parameters_quantization)
or (
not hf_quantizer.check_quantized_param(
model, param, module_name, state_dict, param_device=param_device, device_map=device_map
)
)
):
if is_fsdp_enabled():
param_device = "cpu" if is_local_dist_rank_0() else "meta"
module = model.get_submodule(layer)
if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype:
param = param[:].to(param_casting_dtype)
module.load_state_dict(
{param_type: param[:].to(param_device)},
False,
True,
)
else:
hf_quantizer.create_quantized_param(
model, param[:], module_name, param_device, state_dict, unexpected_keys
)
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
# and then cast it to CPU to avoid excessive memory usage on each GPU
# in comparison to the sharded model across GPUs.
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
module, tensor_name = get_module_from_name(model, module_name)
value = getattr(module, tensor_name)
param_to = "cpu"
if is_fsdp_enabled() and not is_local_dist_rank_0():
param_to = "meta"
val_kwargs = {}
if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
val_kwargs["requires_grad"] = False
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
setattr(module, tensor_name, value)
return error_msgs, offload_index, state_dict_index
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None:
splits = weights_name.split(".")
splits = splits[:-1] + [variant] + splits[-1:]
weights_name = ".".join(splits)
path, name = weights_name.rsplit(".", 1)
weights_name = f"{path}.{variant}.{name}"
return weights_name
@@ -1283,6 +1252,45 @@ class ModuleUtilsMixin:
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
def _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
original_loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
prefix,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key, model_key in zip(original_loaded_keys, loaded_keys):
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue
if remove_prefix_from_model:
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
model_key = f"{prefix}.{model_key}"
elif add_prefix_to_model:
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
model_key = ".".join(model_key.split(".")[1:])
if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
if (
state_dict[checkpoint_key].shape[-1] == 1
and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel()
):
# This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
# Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights.
pass
else:
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
# TODO (joao): remove `GenerationMixin` inheritance in v4.50
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
r"""
@@ -3227,6 +3235,35 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else:
return super().float(*args)
@classmethod
def get_init_context(
cls: Type[SpecificPreTrainedModelType],
_fast_init=True,
is_quantized=None,
_is_ds_init_called=None,
low_cpu_mem_usage=True,
):
init_contexts = [no_init_weights(_enable=_fast_init)]
if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
import deepspeed
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
init_contexts = [
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
set_zero3_state(),
] + init_contexts
elif low_cpu_mem_usage:
if not is_accelerate_available():
raise ImportError(
f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
)
init_contexts.append(init_empty_weights())
if is_deepspeed_zero3_enabled() and is_quantized:
init_contexts.append(set_quantized_state())
return init_contexts
@classmethod
@restore_default_torch_dtype
def from_pretrained(
@@ -3528,12 +3565,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if tp_plan is not None and tp_plan != "auto":
# TODO: we can relax this check when we support taking tp_plan from a json file, for example.
raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
if tp_plan is not None and device_map is not None:
raise ValueError(
"`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization."
)
# If torchrun was used, make sure to TP by default. This way people don't need to change tp or device map
if device_map == "auto" and tp_plan is None and int(os.environ.get("WORLD_SIZE", 0)):
tp_plan = "auto" # device_map = "auto" in torchrun equivalent to TP plan = AUTO!
device_map = None
# We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
# `device_map` pointing to the correct device
device_mesh = None
@@ -3541,7 +3582,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if not is_torch_greater_or_equal("2.5"):
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
if not torch.distributed.is_initialized():
raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")
try:
logger.warning("Tensor Parallel requires torch.distributed to be initialized first.")
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
except Exception as e:
raise EnvironmentError(
"We tried to initialize torch.distributed for you, but it failed, make"
"sure you init torch distributed in your script to use `tp_plan='auto'`"
) from e
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
device_type = torch._C._get_accelerator().type
@@ -4119,7 +4170,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if from_pt:
if not is_sharded and state_dict is None:
# Time to load the checkpoint
state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
state_dict = load_state_dict(resolved_archive_file, map_location="meta", weights_only=weights_only)
# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
@@ -4205,25 +4256,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
config.name_or_path = pretrained_model_name_or_path
# Instantiate model.
init_contexts = [no_init_weights(_enable=_fast_init)]
if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
import deepspeed
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
init_contexts = [
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
set_zero3_state(),
] + init_contexts
elif low_cpu_mem_usage:
if not is_accelerate_available():
raise ImportError(
f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
)
init_contexts.append(init_empty_weights())
if is_deepspeed_zero3_enabled() and is_quantized:
init_contexts.append(set_quantized_state())
model_init_context = cls.get_init_context(_fast_init, is_quantized, _is_ds_init_called, low_cpu_mem_usage)
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
if not getattr(config, "_attn_implementation_autoset", False):
@@ -4231,7 +4264,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
)
with ContextManagers(init_contexts):
with ContextManagers(model_init_context):
# Let's make sure we don't run the init function of buffer modules
model = cls(config, *model_args, **model_kwargs)
@@ -4510,8 +4543,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return key, False
@classmethod
def _fix_state_dict_keys_on_load(cls, state_dict):
def rename_key(self, key):
new_key = key
if len(self.base_model_prefix) > 0:
if not hasattr(self, self.base_model_prefix) and key.startswith(self.base_model_prefix):
new_key = ".".join(key.split(".")[1:])
elif (
hasattr(self, self.base_model_prefix)
and not key.startswith(self.base_model_prefix)
and key not in self.expected_keys
):
new_key = f"{self.base_model_prefix}.{key}"
new_key, has_changed = self._fix_state_dict_key_on_load(new_key)
return new_key, has_changed
def _fix_state_dict_keys_on_load(self, state_dict):
"""Fixes state dict keys by replacing legacy parameter names with their modern equivalents.
Logs if any parameters have been renamed.
"""
@@ -4519,18 +4566,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
renamed_keys = {}
state_dict_keys = list(state_dict.keys())
for key in state_dict_keys:
new_key, has_changed = cls._fix_state_dict_key_on_load(key)
if has_changed:
state_dict[new_key] = state_dict.pop(key)
new_key, has_changed = self.rename_key(key)
state_dict[new_key] = state_dict.pop(key)
# track gamma/beta rename for logging
# track gamma/beta rename for logging
if has_changed:
if key.endswith("LayerNorm.gamma"):
renamed_keys["LayerNorm.gamma"] = (key, new_key)
elif key.endswith("LayerNorm.beta"):
renamed_keys["LayerNorm.beta"] = (key, new_key)
if renamed_keys:
warning_msg = f"A pretrained model of type `{cls.__name__}` "
warning_msg = f"A pretrained model of type `{self.__class__.__name__}` "
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
for old_key, new_key in renamed_keys.values():
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
@@ -4611,7 +4658,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)
original_loaded_keys = loaded_keys
loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_keys]
loaded_keys = [model._fix_state_dict_key_on_load(key)[0] for key in loaded_keys]
if len(prefix) > 0:
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
@@ -4759,11 +4806,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model.apply(model._initialize_weights)
# Set some modules to fp32 if any
if keep_in_fp32_modules == []:
keep_in_fp32_modules = None
if keep_in_fp32_modules is not None:
keep_in_fp32_modules = re.compile("|".join(keep_in_fp32_modules))
for name, param in model.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
if keep_in_fp32_modules.search(name):
# param = param.to(torch.float32) does not work here as only in the local scope.
param.data = param.data.to(torch.float32)
param.data = param.data.to(torch.float32) # TODO @Cyrilvallez: we seem to do this twice
# Make sure we are able to load base models as well as derived models (with heads)
start_prefix = ""
@@ -4781,51 +4831,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if device_map is not None:
device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()}
def _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
original_loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key, model_key in zip(original_loaded_keys, loaded_keys):
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue
if remove_prefix_from_model:
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
model_key = f"{prefix}.{model_key}"
elif add_prefix_to_model:
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
model_key = ".".join(model_key.split(".")[1:])
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
if (
state_dict[checkpoint_key].shape[-1] == 1
and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel()
):
# This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
# Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights.
pass
else:
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
if resolved_archive_file is not None:
folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
else:
folder = None
model.expected_keys = expected_keys
if device_map is not None:
expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
caching_allocator_warmup(model, expanded_device_map, dtype)
@@ -4850,6 +4861,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else:
offload_index = None
error_msgs = []
if state_dict is not None:
# Whole checkpoint
mismatched_keys = _find_mismatched_keys(
@@ -4860,14 +4872,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
prefix,
)
# For GGUF models `state_dict` is never set to None as the state dict is always small
if gguf_path or low_cpu_mem_usage:
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
if gguf_path or low_cpu_mem_usage and is_safetensors:
error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
fixed_state_dict,
state_dict,
start_prefix,
expected_keys,
device_map=device_map,
@@ -4881,17 +4893,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
resolved_archive_file=resolved_archive_file,
)
else:
# Sharded checkpoint or whole but low_cpu_mem_usage==True
# We need to read the state dict as it is meta otherwise
if resolved_archive_file is not None:
state_dict = load_state_dict(resolved_archive_file, map_location="cpu")
assign_to_params_buffers = check_support_param_buffer_assignment(
model_to_load, state_dict, start_prefix
)
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
error_msgs = _load_state_dict_into_model(
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
)
# at this point the state dict should be on cpu, we don't need to actually read it
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
else:
# This should always be a list but, just to be sure.
if not isinstance(resolved_archive_file, list):
@@ -4945,8 +4958,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
prefix,
)
if low_cpu_mem_usage:
if low_cpu_mem_usage and shard_file.endswith(".safetensors"):
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
for key, param in model_to_load.state_dict().items():
if param.device == torch.device("meta"):
@@ -4954,10 +4968,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
fixed_state_dict,
state_dict,
start_prefix,
expected_keys,
device_map=device_map,
@@ -4971,19 +4984,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
shard_file=shard_file,
)
error_msgs += new_error_msgs
else:
state_dict = load_state_dict(shard_file, map_location="cpu", weights_only=weights_only)
# Sharded checkpoint or whole but low_cpu_mem_usage==True
if assign_to_params_buffers is None:
assign_to_params_buffers = check_support_param_buffer_assignment(
model_to_load, state_dict, start_prefix
)
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
error_msgs += _load_state_dict_into_model(
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
)
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
# force memory release
del state_dict
gc.collect()
@@ -5257,6 +5269,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
Calling `from_pretrained(..., tp_plan="auto")` is prefered, and will parallelize module-by-module during initialization,
so that the expected per-device memory spike at loading time is not larger than the final model size on each device.
Tensor parallelize the model across the given device mesh. This function is a helper to be called after the model
was already loaded in memory, note however that this means that each process will first initialize the whole model,
then parallelize it accross devices. Thus there is a huge waste of GPU memory, and this can lead to OOM at loading time.
Args:
device_mesh (`torch.distributed.DeviceMesh`):
@@ -5825,12 +5840,12 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
param = model.get_parameter(param_name)
except AttributeError:
param = model.get_buffer(param_name)
parameter_count[device] += math.prod(param.shape)
parameter_count[device] += int(math.prod(param.shape) * 2)
dtype = dtype if dtype is not None else torch.float32
# This will kick off the caching allocator to avoid having to Malloc afterwards
for device, param_count in parameter_count.items():
_ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
_ = torch.empty(int(param_count), dtype=dtype, device=device, requires_grad=False)
def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):