Update form pretrained to make TP a first class citizen (#36335)
* clean code * oups * fix merge * yups * fix if * now you can play * fix shape issue * try non blocking * fix * updates * up * updates * fix most of thetests * update * update * small updates * up * fix the remaining bug? * update * rename when you read from the file * buffer issues * current status * cleanup * properly allocate dumb memory * update a small bug * fix colwise rep issue * fix keep in float 32 that was keeping everything in float 32 * typo * more fixes with keep_in_fp32_modules as we use to serach on it * fix ROPE dtype for TP * remove what's breaking the tests * updates * update and fixes * small cleanup after merging * allocate 2x to be safe * style, auto * update * yup nit * fix * remove slow as fuck torch api :( * work * fixup * update * brting the fix back * fix and update * fixes Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * updates because some suggestions were wrong 👀 * update? * fuck this bloated function * typo * fix the dumb prefix thing once and forall * fixes here and there * updates * remove prints * fix strict cases * styel * properly fix keys on load! * update * fix base model prefix issue * style * update * fix all? * remoce 1 print * fix the final etsts * fixup * last nits * fix the detach issue which cause a 2x slowdown * fixup * small fixes * ultra nit * fix * fix --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -37,9 +37,11 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVa
|
||||
from zipfile import is_zipfile
|
||||
|
||||
import torch
|
||||
import torch.distributed.tensor
|
||||
from huggingface_hub import split_torch_state_dict_into_shards
|
||||
from packaging import version
|
||||
from torch import Tensor, nn
|
||||
from torch.distributed.tensor import DTensor, Shard
|
||||
from torch.distributions import constraints
|
||||
from torch.nn import CrossEntropyLoss, Identity
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
@@ -56,6 +58,7 @@ from .loss.loss_utils import LOSS_MAPPING
|
||||
from .pytorch_utils import ( # noqa: F401
|
||||
Conv1D,
|
||||
apply_chunking_to_forward,
|
||||
distribute_module,
|
||||
find_pruneable_heads_and_indices,
|
||||
id_tensor_storage,
|
||||
prune_conv1d_layer,
|
||||
@@ -404,9 +407,6 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefi
|
||||
|
||||
Note: We fully disable this if we are using `deepspeed`
|
||||
"""
|
||||
if model_to_load.device.type == "meta":
|
||||
return False
|
||||
|
||||
if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
|
||||
return False
|
||||
|
||||
@@ -514,25 +514,50 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
|
||||
return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
|
||||
|
||||
|
||||
str_to_torch_dtype = {
|
||||
"BOOL": torch.bool,
|
||||
"U8": torch.uint8,
|
||||
"I8": torch.int8,
|
||||
"I16": torch.int16,
|
||||
"U16": torch.uint16,
|
||||
"F16": torch.float16,
|
||||
"BF16": torch.bfloat16,
|
||||
"I32": torch.int32,
|
||||
"U32": torch.uint32,
|
||||
"F32": torch.float32,
|
||||
"F64": torch.float64,
|
||||
"I64": torch.int64,
|
||||
"U64": torch.uint64,
|
||||
}
|
||||
|
||||
|
||||
def load_state_dict(
|
||||
checkpoint_file: Union[str, os.PathLike],
|
||||
is_quantized: bool = False,
|
||||
map_location: Optional[Union[str, torch.device]] = None,
|
||||
map_location: Optional[Union[str, torch.device]] = "meta",
|
||||
weights_only: bool = True,
|
||||
):
|
||||
"""
|
||||
Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
|
||||
Reads a `safetensor` or a `.bin` checkpoint file into `meta` if requested.
|
||||
"""
|
||||
if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
|
||||
# Check format of the archive
|
||||
with safe_open(checkpoint_file, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
|
||||
raise OSError(
|
||||
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
|
||||
"you save your model with the `save_pretrained` method."
|
||||
)
|
||||
return safe_load_file(checkpoint_file)
|
||||
|
||||
if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
|
||||
raise OSError(
|
||||
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
|
||||
"you save your model with the `save_pretrained` method."
|
||||
)
|
||||
state_dict = {}
|
||||
for k in f.keys():
|
||||
dtype = str_to_torch_dtype[f.get_slice(k).get_dtype()]
|
||||
if map_location == "meta":
|
||||
state_dict[k] = torch.empty(size=f.get_slice(k).get_shape(), dtype=dtype, device="meta")
|
||||
else:
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
return state_dict
|
||||
|
||||
try:
|
||||
if map_location is None:
|
||||
if (
|
||||
@@ -677,54 +702,6 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
|
||||
return shared_tensors, identical
|
||||
|
||||
|
||||
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, "_metadata", None)
|
||||
state_dict = state_dict.copy()
|
||||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
|
||||
error_msgs = []
|
||||
|
||||
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
||||
# so we need to apply the function recursively.
|
||||
def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
|
||||
|
||||
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
|
||||
# Parameters of module and children will start with prefix. We can exit early if there are none in this
|
||||
# state_dict
|
||||
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed
|
||||
|
||||
# In sharded models, each shard has only part of the full state_dict, so only gather
|
||||
# parameters that are in the current state_dict.
|
||||
named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
|
||||
params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
|
||||
if len(params_to_gather) > 0:
|
||||
# because zero3 puts placeholders in model params, this context
|
||||
# manager gathers (unpartitions) the params of the current layer, then loads from
|
||||
# the state dict and then re-partitions them again
|
||||
with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
|
||||
if torch.distributed.get_rank() == 0:
|
||||
module._load_from_state_dict(*args)
|
||||
else:
|
||||
module._load_from_state_dict(*args)
|
||||
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
|
||||
|
||||
load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
|
||||
# Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
|
||||
# it's safe to delete it.
|
||||
del state_dict
|
||||
|
||||
return error_msgs
|
||||
|
||||
|
||||
def find_submodule_and_param_name(model, long_key, start_prefix):
|
||||
"""
|
||||
A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied it'll be removed
|
||||
@@ -774,9 +751,10 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
|
||||
setattr(submodule, param_name, new_val)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _load_state_dict_into_meta_model(
|
||||
model,
|
||||
state_dict,
|
||||
model: torch.nn.Module,
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
start_prefix,
|
||||
expected_keys,
|
||||
device_map=None,
|
||||
@@ -791,6 +769,7 @@ def _load_state_dict_into_meta_model(
|
||||
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
|
||||
pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys
|
||||
device_mesh=None,
|
||||
shard_file=None,
|
||||
):
|
||||
"""
|
||||
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
|
||||
@@ -803,167 +782,157 @@ def _load_state_dict_into_meta_model(
|
||||
It also initialize tensor parallelism for each module if needed.
|
||||
|
||||
"""
|
||||
tensor_device = None
|
||||
if device_map is not None and device_map.get("", None) is not None:
|
||||
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
|
||||
|
||||
# XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
|
||||
# - deepspeed zero 3 support
|
||||
# - need to copy metadata if any - see _load_state_dict_into_model
|
||||
# - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
|
||||
with safe_open(shard_file, framework="pt", device=tensor_device) as file_pointer:
|
||||
error_msgs = []
|
||||
|
||||
error_msgs = []
|
||||
is_quantized = hf_quantizer is not None
|
||||
|
||||
is_quantized = hf_quantizer is not None
|
||||
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
|
||||
|
||||
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
|
||||
|
||||
# we need this later to initialize tensor parallelism
|
||||
if device_mesh is not None:
|
||||
full_tp_plan = model.config.base_model_tp_plan
|
||||
for submodule in model.modules():
|
||||
full_tp_plan.update(getattr(submodule, "_tp_plan", {}))
|
||||
|
||||
for param_name, param in state_dict.items():
|
||||
if param_name not in expected_keys:
|
||||
continue
|
||||
|
||||
if param_name.startswith(start_prefix):
|
||||
param_name = param_name[len(start_prefix) :]
|
||||
|
||||
module_name = param_name
|
||||
set_module_kwargs = {}
|
||||
|
||||
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
|
||||
# in int/uint/bool and not cast them.
|
||||
is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
|
||||
if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn:
|
||||
if (
|
||||
keep_in_fp32_modules is not None
|
||||
and any(
|
||||
module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
|
||||
)
|
||||
and dtype == torch.float16
|
||||
):
|
||||
param = param.to(torch.float32)
|
||||
|
||||
# For backward compatibility with older versions of `accelerate`
|
||||
# TODO: @sgugger replace this check with version check at the next `accelerate` release
|
||||
if "dtype" in list(inspect.signature(set_module_tensor_to_device).parameters):
|
||||
set_module_kwargs["dtype"] = torch.float32
|
||||
else:
|
||||
param = param.to(dtype)
|
||||
|
||||
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
|
||||
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
|
||||
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
|
||||
|
||||
old_param = model
|
||||
splits = param_name.split(".")
|
||||
for split in splits:
|
||||
# We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys.
|
||||
old_param = getattr(old_param, split, None)
|
||||
if old_param is None:
|
||||
break
|
||||
|
||||
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
|
||||
old_param = None
|
||||
|
||||
if old_param is not None:
|
||||
if dtype is None:
|
||||
param = param.to(old_param.dtype)
|
||||
|
||||
if old_param.is_contiguous():
|
||||
param = param.contiguous()
|
||||
|
||||
set_module_kwargs["value"] = param
|
||||
|
||||
if device_map is None:
|
||||
param_device = "cpu"
|
||||
else:
|
||||
# find next higher level module that is defined in device_map:
|
||||
# bert.lm_head.weight -> bert.lm_head -> bert -> ''
|
||||
while len(module_name) > 0 and module_name not in device_map:
|
||||
module_name = ".".join(module_name.split(".")[:-1])
|
||||
if module_name == "" and "" not in device_map:
|
||||
# TODO: group all errors and raise at the end.
|
||||
raise ValueError(f"{param_name} doesn't have any device set.")
|
||||
param_device = device_map[module_name]
|
||||
|
||||
if param_device == "disk":
|
||||
if not is_safetensors:
|
||||
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
|
||||
elif param_device == "cpu" and state_dict_index is not None:
|
||||
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
|
||||
elif (
|
||||
not is_quantized
|
||||
or (not hf_quantizer.requires_parameters_quantization)
|
||||
or (
|
||||
not hf_quantizer.check_quantized_param(
|
||||
model, param, param_name, state_dict, param_device=param_device, device_map=device_map
|
||||
)
|
||||
)
|
||||
):
|
||||
if is_fsdp_enabled():
|
||||
param_device = "cpu" if is_local_dist_rank_0() else "meta"
|
||||
|
||||
# For backward compatibility with older versions of `accelerate` and for non-quantized params
|
||||
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
|
||||
else:
|
||||
hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
|
||||
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
|
||||
# and then cast it to CPU to avoid excessive memory usage on each GPU
|
||||
# in comparison to the sharded model across GPUs.
|
||||
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
value = getattr(module, tensor_name)
|
||||
param_to = "cpu"
|
||||
if is_fsdp_enabled() and not is_local_dist_rank_0():
|
||||
param_to = "meta"
|
||||
val_kwargs = {}
|
||||
if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
|
||||
val_kwargs["requires_grad"] = False
|
||||
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
|
||||
setattr(module, tensor_name, value)
|
||||
# TODO: consider removing used param_parts from state_dict before return
|
||||
|
||||
# In this case, let's parallelize the modules!
|
||||
# we need this later to initialize tensor parallelism
|
||||
if device_mesh is not None:
|
||||
# Immediate parent
|
||||
split_parent_module_name = param_name.split(".")[:-1]
|
||||
parent_module_name = ".".join(split_parent_module_name)
|
||||
parent_module = model
|
||||
for name in split_parent_module_name:
|
||||
parent_module = getattr(parent_module, name)
|
||||
full_tp_plan = model.config.base_model_tp_plan
|
||||
for submodule in model.modules():
|
||||
full_tp_plan.update(getattr(submodule, "_tp_plan", {}))
|
||||
|
||||
# Check if we are part of the tp_plan
|
||||
current_module_plan = None
|
||||
for param, plan in full_tp_plan.items():
|
||||
# "*" are a placeholder for layer indices, so we replace them by "[0-9]+" in the regex pattern
|
||||
pattern = param.replace("*", "[0-9]+")
|
||||
if re.search(pattern, parent_module_name):
|
||||
current_module_plan = plan
|
||||
break
|
||||
for serialized_param_name, empty_param in state_dict.items():
|
||||
# param_name is the raw, serialized name
|
||||
# new_param_name is the model's equivalent
|
||||
module_name, _ = model.rename_key(serialized_param_name)
|
||||
if module_name not in expected_keys:
|
||||
continue
|
||||
layer, param_type = module_name.rsplit(".", 1)
|
||||
|
||||
# We can only apply the tp_plan after all parameters of the current module have been correctly initialized (e.g.
|
||||
# if we have bias, we need both `weights` and `bias` of a nn.Linear to be initialized)
|
||||
process_device = list(device_map.values())[0]
|
||||
all_module_parameters_initialized = all(
|
||||
m.device == process_device for m in parent_module.parameters(recurse=False)
|
||||
) and all(m.device == process_device for m in parent_module.buffers(recurse=False))
|
||||
if current_module_plan is not None and all_module_parameters_initialized:
|
||||
torch.distributed.tensor.parallel.parallelize_module(
|
||||
parent_module,
|
||||
device_mesh=device_mesh,
|
||||
parallelize_plan=translate_to_torch_parallel_style(current_module_plan),
|
||||
)
|
||||
# param name needs to stay untouched as it's in the file
|
||||
param = file_pointer.get_slice(serialized_param_name)
|
||||
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
|
||||
# in int/uint/bool and not cast them.
|
||||
param_casting_dtype = None
|
||||
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
|
||||
if dtype is not None and empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
|
||||
if (
|
||||
keep_in_fp32_modules is not None
|
||||
and keep_in_fp32_modules.search(module_name)
|
||||
and dtype == torch.float16
|
||||
):
|
||||
param_casting_dtype = torch.float32
|
||||
else:
|
||||
param_casting_dtype = dtype
|
||||
|
||||
if device_mesh is not None: # In this case, the param is already on the correct device!
|
||||
try:
|
||||
module_to_tp: torch.nn.Module = model.get_submodule(layer)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
"The config tp plan is wrong because the layer is not a liner layer, nor an embedding"
|
||||
)
|
||||
current_module_plan = None
|
||||
full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+")
|
||||
if plan := re.search(full_tp_plan_, module_name):
|
||||
match = re.sub("[0-9]+", "*", plan[0])
|
||||
current_module_plan = full_tp_plan[match]
|
||||
|
||||
if current_module_plan is not None:
|
||||
tp_layer = translate_to_torch_parallel_style(current_module_plan)
|
||||
rank = tensor_device
|
||||
row, col = empty_param.shape
|
||||
if "rowwise" == current_module_plan:
|
||||
param = param[:, rank * (col // device_mesh.size()) : (rank + 1) * (col // device_mesh.size())]
|
||||
shard = Shard(1)
|
||||
tp_layer.desired_input_layouts = (Shard(-1),)
|
||||
elif "colwise" == current_module_plan:
|
||||
param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
|
||||
shard = Shard(0)
|
||||
else:
|
||||
param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
|
||||
shard = Shard(0)
|
||||
if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype:
|
||||
param = param.to(param_casting_dtype)
|
||||
local_parameter = DTensor.from_local(
|
||||
param,
|
||||
device_mesh=device_mesh,
|
||||
placements=[shard] * device_mesh.ndim,
|
||||
)
|
||||
if isinstance(module_to_tp.weight, nn.Parameter):
|
||||
local_parameter = torch.nn.Parameter(local_parameter)
|
||||
module_to_tp.weight = local_parameter
|
||||
input_fn = partial(
|
||||
tp_layer._prepare_input_fn, tp_layer.input_layouts, tp_layer.desired_input_layouts
|
||||
)
|
||||
output_fn = partial(
|
||||
tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output
|
||||
)
|
||||
distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn)
|
||||
else:
|
||||
module_to_tp.load_state_dict({param_type: param[:]}, False, True)
|
||||
|
||||
else:
|
||||
if device_map is None:
|
||||
param_device = "cpu"
|
||||
else:
|
||||
module_name = module_name.rsplit(".", 1)[0]
|
||||
device_map_regex = "|".join(device_map.keys())
|
||||
module_layer = re.search(device_map_regex, module_name)
|
||||
if module_name == "" or device_map_regex is None:
|
||||
raise ValueError(
|
||||
f"`device_map` is used, but {module_name} doesn't have any device set. {device_map}"
|
||||
)
|
||||
else:
|
||||
param_device = device_map[module_layer.group()]
|
||||
|
||||
if param_device == "disk" and not is_safetensors:
|
||||
offload_index = offload_weight(param[:], module_name, offload_folder, offload_index)
|
||||
elif param_device == "cpu" and state_dict_index is not None:
|
||||
state_dict_index = offload_weight(param[:], module_name, state_dict_folder, state_dict_index)
|
||||
elif (
|
||||
not is_quantized
|
||||
or (not hf_quantizer.requires_parameters_quantization)
|
||||
or (
|
||||
not hf_quantizer.check_quantized_param(
|
||||
model, param, module_name, state_dict, param_device=param_device, device_map=device_map
|
||||
)
|
||||
)
|
||||
):
|
||||
if is_fsdp_enabled():
|
||||
param_device = "cpu" if is_local_dist_rank_0() else "meta"
|
||||
module = model.get_submodule(layer)
|
||||
if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype:
|
||||
param = param[:].to(param_casting_dtype)
|
||||
module.load_state_dict(
|
||||
{param_type: param[:].to(param_device)},
|
||||
False,
|
||||
True,
|
||||
)
|
||||
else:
|
||||
hf_quantizer.create_quantized_param(
|
||||
model, param[:], module_name, param_device, state_dict, unexpected_keys
|
||||
)
|
||||
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
|
||||
# and then cast it to CPU to avoid excessive memory usage on each GPU
|
||||
# in comparison to the sharded model across GPUs.
|
||||
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
|
||||
module, tensor_name = get_module_from_name(model, module_name)
|
||||
value = getattr(module, tensor_name)
|
||||
param_to = "cpu"
|
||||
if is_fsdp_enabled() and not is_local_dist_rank_0():
|
||||
param_to = "meta"
|
||||
val_kwargs = {}
|
||||
if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
|
||||
val_kwargs["requires_grad"] = False
|
||||
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
|
||||
setattr(module, tensor_name, value)
|
||||
|
||||
return error_msgs, offload_index, state_dict_index
|
||||
|
||||
|
||||
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
||||
if variant is not None:
|
||||
splits = weights_name.split(".")
|
||||
splits = splits[:-1] + [variant] + splits[-1:]
|
||||
weights_name = ".".join(splits)
|
||||
|
||||
path, name = weights_name.rsplit(".", 1)
|
||||
weights_name = f"{path}.{variant}.{name}"
|
||||
return weights_name
|
||||
|
||||
|
||||
@@ -1283,6 +1252,45 @@ class ModuleUtilsMixin:
|
||||
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
|
||||
|
||||
|
||||
def _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
loaded_keys,
|
||||
original_loaded_keys,
|
||||
add_prefix_to_model,
|
||||
remove_prefix_from_model,
|
||||
ignore_mismatched_sizes,
|
||||
prefix,
|
||||
):
|
||||
mismatched_keys = []
|
||||
if ignore_mismatched_sizes:
|
||||
for checkpoint_key, model_key in zip(original_loaded_keys, loaded_keys):
|
||||
# If the checkpoint is sharded, we may not have the key here.
|
||||
if checkpoint_key not in state_dict:
|
||||
continue
|
||||
if remove_prefix_from_model:
|
||||
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
|
||||
model_key = f"{prefix}.{model_key}"
|
||||
elif add_prefix_to_model:
|
||||
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
|
||||
model_key = ".".join(model_key.split(".")[1:])
|
||||
|
||||
if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
|
||||
if (
|
||||
state_dict[checkpoint_key].shape[-1] == 1
|
||||
and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel()
|
||||
):
|
||||
# This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
|
||||
# Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights.
|
||||
pass
|
||||
else:
|
||||
mismatched_keys.append(
|
||||
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
||||
)
|
||||
del state_dict[checkpoint_key]
|
||||
return mismatched_keys
|
||||
|
||||
|
||||
# TODO (joao): remove `GenerationMixin` inheritance in v4.50
|
||||
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
|
||||
r"""
|
||||
@@ -3227,6 +3235,35 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
else:
|
||||
return super().float(*args)
|
||||
|
||||
@classmethod
|
||||
def get_init_context(
|
||||
cls: Type[SpecificPreTrainedModelType],
|
||||
_fast_init=True,
|
||||
is_quantized=None,
|
||||
_is_ds_init_called=None,
|
||||
low_cpu_mem_usage=True,
|
||||
):
|
||||
init_contexts = [no_init_weights(_enable=_fast_init)]
|
||||
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
|
||||
import deepspeed
|
||||
|
||||
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
||||
init_contexts = [
|
||||
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
|
||||
set_zero3_state(),
|
||||
] + init_contexts
|
||||
elif low_cpu_mem_usage:
|
||||
if not is_accelerate_available():
|
||||
raise ImportError(
|
||||
f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
|
||||
)
|
||||
init_contexts.append(init_empty_weights())
|
||||
|
||||
if is_deepspeed_zero3_enabled() and is_quantized:
|
||||
init_contexts.append(set_quantized_state())
|
||||
return init_contexts
|
||||
|
||||
@classmethod
|
||||
@restore_default_torch_dtype
|
||||
def from_pretrained(
|
||||
@@ -3528,12 +3565,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
if tp_plan is not None and tp_plan != "auto":
|
||||
# TODO: we can relax this check when we support taking tp_plan from a json file, for example.
|
||||
raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
|
||||
|
||||
if tp_plan is not None and device_map is not None:
|
||||
raise ValueError(
|
||||
"`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization."
|
||||
)
|
||||
|
||||
# If torchrun was used, make sure to TP by default. This way people don't need to change tp or device map
|
||||
if device_map == "auto" and tp_plan is None and int(os.environ.get("WORLD_SIZE", 0)):
|
||||
tp_plan = "auto" # device_map = "auto" in torchrun equivalent to TP plan = AUTO!
|
||||
device_map = None
|
||||
|
||||
# We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
|
||||
# `device_map` pointing to the correct device
|
||||
device_mesh = None
|
||||
@@ -3541,7 +3582,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
if not is_torch_greater_or_equal("2.5"):
|
||||
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
|
||||
if not torch.distributed.is_initialized():
|
||||
raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")
|
||||
try:
|
||||
logger.warning("Tensor Parallel requires torch.distributed to be initialized first.")
|
||||
rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||
torch.cuda.set_device(rank)
|
||||
except Exception as e:
|
||||
raise EnvironmentError(
|
||||
"We tried to initialize torch.distributed for you, but it failed, make"
|
||||
"sure you init torch distributed in your script to use `tp_plan='auto'`"
|
||||
) from e
|
||||
|
||||
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
|
||||
device_type = torch._C._get_accelerator().type
|
||||
@@ -4119,7 +4170,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
if from_pt:
|
||||
if not is_sharded and state_dict is None:
|
||||
# Time to load the checkpoint
|
||||
state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
|
||||
state_dict = load_state_dict(resolved_archive_file, map_location="meta", weights_only=weights_only)
|
||||
|
||||
# set dtype to instantiate the model under:
|
||||
# 1. If torch_dtype is not None, we use that dtype
|
||||
@@ -4205,25 +4256,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
config.name_or_path = pretrained_model_name_or_path
|
||||
|
||||
# Instantiate model.
|
||||
init_contexts = [no_init_weights(_enable=_fast_init)]
|
||||
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
|
||||
import deepspeed
|
||||
|
||||
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
||||
init_contexts = [
|
||||
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
|
||||
set_zero3_state(),
|
||||
] + init_contexts
|
||||
elif low_cpu_mem_usage:
|
||||
if not is_accelerate_available():
|
||||
raise ImportError(
|
||||
f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
|
||||
)
|
||||
init_contexts.append(init_empty_weights())
|
||||
|
||||
if is_deepspeed_zero3_enabled() and is_quantized:
|
||||
init_contexts.append(set_quantized_state())
|
||||
model_init_context = cls.get_init_context(_fast_init, is_quantized, _is_ds_init_called, low_cpu_mem_usage)
|
||||
|
||||
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
|
||||
if not getattr(config, "_attn_implementation_autoset", False):
|
||||
@@ -4231,7 +4264,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
|
||||
)
|
||||
|
||||
with ContextManagers(init_contexts):
|
||||
with ContextManagers(model_init_context):
|
||||
# Let's make sure we don't run the init function of buffer modules
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
|
||||
@@ -4510,8 +4543,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
return key, False
|
||||
|
||||
@classmethod
|
||||
def _fix_state_dict_keys_on_load(cls, state_dict):
|
||||
def rename_key(self, key):
|
||||
new_key = key
|
||||
if len(self.base_model_prefix) > 0:
|
||||
if not hasattr(self, self.base_model_prefix) and key.startswith(self.base_model_prefix):
|
||||
new_key = ".".join(key.split(".")[1:])
|
||||
elif (
|
||||
hasattr(self, self.base_model_prefix)
|
||||
and not key.startswith(self.base_model_prefix)
|
||||
and key not in self.expected_keys
|
||||
):
|
||||
new_key = f"{self.base_model_prefix}.{key}"
|
||||
|
||||
new_key, has_changed = self._fix_state_dict_key_on_load(new_key)
|
||||
return new_key, has_changed
|
||||
|
||||
def _fix_state_dict_keys_on_load(self, state_dict):
|
||||
"""Fixes state dict keys by replacing legacy parameter names with their modern equivalents.
|
||||
Logs if any parameters have been renamed.
|
||||
"""
|
||||
@@ -4519,18 +4566,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
renamed_keys = {}
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
for key in state_dict_keys:
|
||||
new_key, has_changed = cls._fix_state_dict_key_on_load(key)
|
||||
if has_changed:
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
new_key, has_changed = self.rename_key(key)
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
# track gamma/beta rename for logging
|
||||
# track gamma/beta rename for logging
|
||||
if has_changed:
|
||||
if key.endswith("LayerNorm.gamma"):
|
||||
renamed_keys["LayerNorm.gamma"] = (key, new_key)
|
||||
elif key.endswith("LayerNorm.beta"):
|
||||
renamed_keys["LayerNorm.beta"] = (key, new_key)
|
||||
|
||||
if renamed_keys:
|
||||
warning_msg = f"A pretrained model of type `{cls.__name__}` "
|
||||
warning_msg = f"A pretrained model of type `{self.__class__.__name__}` "
|
||||
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
|
||||
for old_key, new_key in renamed_keys.values():
|
||||
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
|
||||
@@ -4611,7 +4658,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)
|
||||
|
||||
original_loaded_keys = loaded_keys
|
||||
loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_keys]
|
||||
loaded_keys = [model._fix_state_dict_key_on_load(key)[0] for key in loaded_keys]
|
||||
|
||||
if len(prefix) > 0:
|
||||
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
|
||||
@@ -4759,11 +4806,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
model.apply(model._initialize_weights)
|
||||
|
||||
# Set some modules to fp32 if any
|
||||
if keep_in_fp32_modules == []:
|
||||
keep_in_fp32_modules = None
|
||||
if keep_in_fp32_modules is not None:
|
||||
keep_in_fp32_modules = re.compile("|".join(keep_in_fp32_modules))
|
||||
for name, param in model.named_parameters():
|
||||
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
|
||||
if keep_in_fp32_modules.search(name):
|
||||
# param = param.to(torch.float32) does not work here as only in the local scope.
|
||||
param.data = param.data.to(torch.float32)
|
||||
param.data = param.data.to(torch.float32) # TODO @Cyrilvallez: we seem to do this twice
|
||||
|
||||
# Make sure we are able to load base models as well as derived models (with heads)
|
||||
start_prefix = ""
|
||||
@@ -4781,51 +4831,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
if device_map is not None:
|
||||
device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()}
|
||||
|
||||
def _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
loaded_keys,
|
||||
original_loaded_keys,
|
||||
add_prefix_to_model,
|
||||
remove_prefix_from_model,
|
||||
ignore_mismatched_sizes,
|
||||
):
|
||||
mismatched_keys = []
|
||||
if ignore_mismatched_sizes:
|
||||
for checkpoint_key, model_key in zip(original_loaded_keys, loaded_keys):
|
||||
# If the checkpoint is sharded, we may not have the key here.
|
||||
if checkpoint_key not in state_dict:
|
||||
continue
|
||||
if remove_prefix_from_model:
|
||||
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
|
||||
model_key = f"{prefix}.{model_key}"
|
||||
elif add_prefix_to_model:
|
||||
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
|
||||
model_key = ".".join(model_key.split(".")[1:])
|
||||
|
||||
if (
|
||||
model_key in model_state_dict
|
||||
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
||||
):
|
||||
if (
|
||||
state_dict[checkpoint_key].shape[-1] == 1
|
||||
and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel()
|
||||
):
|
||||
# This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
|
||||
# Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights.
|
||||
pass
|
||||
else:
|
||||
mismatched_keys.append(
|
||||
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
||||
)
|
||||
del state_dict[checkpoint_key]
|
||||
return mismatched_keys
|
||||
|
||||
if resolved_archive_file is not None:
|
||||
folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
|
||||
else:
|
||||
folder = None
|
||||
|
||||
model.expected_keys = expected_keys
|
||||
if device_map is not None:
|
||||
expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
|
||||
caching_allocator_warmup(model, expanded_device_map, dtype)
|
||||
@@ -4850,6 +4861,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
else:
|
||||
offload_index = None
|
||||
|
||||
error_msgs = []
|
||||
if state_dict is not None:
|
||||
# Whole checkpoint
|
||||
mismatched_keys = _find_mismatched_keys(
|
||||
@@ -4860,14 +4872,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
add_prefix_to_model,
|
||||
remove_prefix_from_model,
|
||||
ignore_mismatched_sizes,
|
||||
prefix,
|
||||
)
|
||||
|
||||
# For GGUF models `state_dict` is never set to None as the state dict is always small
|
||||
if gguf_path or low_cpu_mem_usage:
|
||||
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
||||
if gguf_path or low_cpu_mem_usage and is_safetensors:
|
||||
error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
||||
model_to_load,
|
||||
fixed_state_dict,
|
||||
state_dict,
|
||||
start_prefix,
|
||||
expected_keys,
|
||||
device_map=device_map,
|
||||
@@ -4881,17 +4893,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
unexpected_keys=unexpected_keys,
|
||||
device_mesh=device_mesh,
|
||||
resolved_archive_file=resolved_archive_file,
|
||||
)
|
||||
else:
|
||||
# Sharded checkpoint or whole but low_cpu_mem_usage==True
|
||||
# We need to read the state dict as it is meta otherwise
|
||||
if resolved_archive_file is not None:
|
||||
state_dict = load_state_dict(resolved_archive_file, map_location="cpu")
|
||||
assign_to_params_buffers = check_support_param_buffer_assignment(
|
||||
model_to_load, state_dict, start_prefix
|
||||
)
|
||||
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
||||
error_msgs = _load_state_dict_into_model(
|
||||
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
|
||||
)
|
||||
|
||||
# at this point the state dict should be on cpu, we don't need to actually read it
|
||||
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
|
||||
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
|
||||
else:
|
||||
# This should always be a list but, just to be sure.
|
||||
if not isinstance(resolved_archive_file, list):
|
||||
@@ -4945,8 +4958,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
add_prefix_to_model,
|
||||
remove_prefix_from_model,
|
||||
ignore_mismatched_sizes,
|
||||
prefix,
|
||||
)
|
||||
if low_cpu_mem_usage:
|
||||
if low_cpu_mem_usage and shard_file.endswith(".safetensors"):
|
||||
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
|
||||
for key, param in model_to_load.state_dict().items():
|
||||
if param.device == torch.device("meta"):
|
||||
@@ -4954,10 +4968,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
|
||||
)
|
||||
else:
|
||||
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
||||
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
||||
model_to_load,
|
||||
fixed_state_dict,
|
||||
state_dict,
|
||||
start_prefix,
|
||||
expected_keys,
|
||||
device_map=device_map,
|
||||
@@ -4971,19 +4984,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
unexpected_keys=unexpected_keys,
|
||||
device_mesh=device_mesh,
|
||||
shard_file=shard_file,
|
||||
)
|
||||
error_msgs += new_error_msgs
|
||||
else:
|
||||
state_dict = load_state_dict(shard_file, map_location="cpu", weights_only=weights_only)
|
||||
# Sharded checkpoint or whole but low_cpu_mem_usage==True
|
||||
if assign_to_params_buffers is None:
|
||||
assign_to_params_buffers = check_support_param_buffer_assignment(
|
||||
model_to_load, state_dict, start_prefix
|
||||
)
|
||||
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
||||
error_msgs += _load_state_dict_into_model(
|
||||
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
|
||||
)
|
||||
|
||||
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
|
||||
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
|
||||
# force memory release
|
||||
del state_dict
|
||||
gc.collect()
|
||||
@@ -5257,6 +5269,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
Calling `from_pretrained(..., tp_plan="auto")` is prefered, and will parallelize module-by-module during initialization,
|
||||
so that the expected per-device memory spike at loading time is not larger than the final model size on each device.
|
||||
Tensor parallelize the model across the given device mesh. This function is a helper to be called after the model
|
||||
was already loaded in memory, note however that this means that each process will first initialize the whole model,
|
||||
then parallelize it accross devices. Thus there is a huge waste of GPU memory, and this can lead to OOM at loading time.
|
||||
|
||||
Args:
|
||||
device_mesh (`torch.distributed.DeviceMesh`):
|
||||
@@ -5825,12 +5840,12 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
|
||||
param = model.get_parameter(param_name)
|
||||
except AttributeError:
|
||||
param = model.get_buffer(param_name)
|
||||
parameter_count[device] += math.prod(param.shape)
|
||||
parameter_count[device] += int(math.prod(param.shape) * 2)
|
||||
|
||||
dtype = dtype if dtype is not None else torch.float32
|
||||
# This will kick off the caching allocator to avoid having to Malloc afterwards
|
||||
for device, param_count in parameter_count.items():
|
||||
_ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
|
||||
_ = torch.empty(int(param_count), dtype=dtype, device=device, requires_grad=False)
|
||||
|
||||
|
||||
def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
|
||||
|
||||
Reference in New Issue
Block a user