fix torch_dtype, contiguous, and load_state_dict regression (#36512)
* fix regression * fix param * fix load_state_dict * style * better fix for module * fix tests * quick fix for now * rm print
This commit is contained in:
@@ -67,6 +67,7 @@ from .pytorch_utils import ( # noqa: F401
|
||||
translate_to_torch_parallel_style,
|
||||
)
|
||||
from .quantizers import AutoHfQuantizer, HfQuantizer
|
||||
from .quantizers.quantizers_utils import get_module_from_name
|
||||
from .safetensors_conversion import auto_conversion
|
||||
from .utils import (
|
||||
ACCELERATE_MIN_VERSION,
|
||||
@@ -536,11 +537,11 @@ str_to_torch_dtype = {
|
||||
def load_state_dict(
|
||||
checkpoint_file: Union[str, os.PathLike],
|
||||
is_quantized: bool = False,
|
||||
map_location: Optional[Union[str, torch.device]] = "meta",
|
||||
map_location: Optional[Union[str, torch.device]] = "cpu",
|
||||
weights_only: bool = True,
|
||||
):
|
||||
"""
|
||||
Reads a `safetensor` or a `.bin` checkpoint file into `meta` if requested.
|
||||
Reads a `safetensors` or a `.bin` checkpoint file. The checkpoint is loaded on "cpu" by default.
|
||||
"""
|
||||
if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
|
||||
with safe_open(checkpoint_file, framework="pt") as f:
|
||||
@@ -771,6 +772,7 @@ def _load_state_dict_into_meta_model(
|
||||
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
|
||||
device_mesh=None,
|
||||
shard_file=None,
|
||||
weights_only=True,
|
||||
):
|
||||
"""
|
||||
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
|
||||
@@ -800,7 +802,15 @@ def _load_state_dict_into_meta_model(
|
||||
if shard_file.endswith(".safetensors"):
|
||||
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
|
||||
else:
|
||||
bin_state_dict = load_state_dict(shard_file, map_location="cpu")
|
||||
map_location = "cpu"
|
||||
if (
|
||||
device_map is not None
|
||||
and hf_quantizer is not None
|
||||
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
|
||||
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
|
||||
):
|
||||
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
|
||||
bin_state_dict = load_state_dict(shard_file, map_location=map_location, weights_only=weights_only)
|
||||
|
||||
error_msgs = []
|
||||
|
||||
@@ -822,23 +832,36 @@ def _load_state_dict_into_meta_model(
|
||||
if shard_file.endswith(".safetensors")
|
||||
else bin_state_dict[serialized_param_name]
|
||||
)
|
||||
|
||||
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
|
||||
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
|
||||
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
|
||||
|
||||
old_param = model
|
||||
splits = fixed_param_name.split(".")
|
||||
for split in splits:
|
||||
# We shouldn't hit the default value, except for quant methods like hqq that modify expected_keys.
|
||||
old_param = getattr(old_param, split, None)
|
||||
if old_param is None:
|
||||
break
|
||||
|
||||
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
|
||||
old_param = None
|
||||
|
||||
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
|
||||
# in int/uint/bool and not cast them.
|
||||
param_casting_dtype = None
|
||||
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
|
||||
|
||||
if dtype is not None and empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
|
||||
if (
|
||||
keep_in_fp32_modules is not None
|
||||
and keep_in_fp32_modules.search(fixed_param_name)
|
||||
and dtype == torch.float16
|
||||
):
|
||||
if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
|
||||
if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(fixed_param_name):
|
||||
param_casting_dtype = torch.float32
|
||||
else:
|
||||
elif dtype is not None:
|
||||
param_casting_dtype = dtype
|
||||
elif old_param is not None:
|
||||
param_casting_dtype = old_param.dtype
|
||||
|
||||
if device_mesh is not None: # In this case, the param is already on the correct device!
|
||||
module_to_tp, param_type = find_submodule_and_param_name(model, fixed_param_name)
|
||||
module_to_tp, param_type = get_module_from_name(model, fixed_param_name)
|
||||
current_module_plan = None
|
||||
full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+")
|
||||
if plan := re.search(full_tp_plan_, fixed_param_name):
|
||||
@@ -859,8 +882,10 @@ def _load_state_dict_into_meta_model(
|
||||
else:
|
||||
param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
|
||||
shard = Shard(0)
|
||||
if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype:
|
||||
if param_casting_dtype is not None:
|
||||
param = param.to(param_casting_dtype)
|
||||
if old_param.is_contiguous():
|
||||
param = param.contiguous()
|
||||
local_parameter = DTensor.from_local(
|
||||
param,
|
||||
device_mesh=device_mesh,
|
||||
@@ -873,9 +898,18 @@ def _load_state_dict_into_meta_model(
|
||||
output_fn = partial(tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output)
|
||||
distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn)
|
||||
else:
|
||||
module_to_tp.load_state_dict({param_type: param[:]}, strict=False, assign=True)
|
||||
param = param[:]
|
||||
if old_param is not None and old_param.is_contiguous():
|
||||
param = param.contiguous()
|
||||
module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
|
||||
|
||||
else:
|
||||
param = param[:]
|
||||
if param_casting_dtype is not None:
|
||||
param = param.to(param_casting_dtype)
|
||||
if old_param is not None and old_param.is_contiguous():
|
||||
param = param.contiguous()
|
||||
|
||||
if device_map is None:
|
||||
param_device = "cpu"
|
||||
else:
|
||||
@@ -887,9 +921,9 @@ def _load_state_dict_into_meta_model(
|
||||
|
||||
if param_device == "disk":
|
||||
if not is_safetensors:
|
||||
offload_index = offload_weight(param[:], fixed_param_name, offload_folder, offload_index)
|
||||
offload_index = offload_weight(param, fixed_param_name, offload_folder, offload_index)
|
||||
elif param_device == "cpu" and state_dict_index is not None:
|
||||
state_dict_index = offload_weight(param[:], fixed_param_name, state_dict_folder, state_dict_index)
|
||||
state_dict_index = offload_weight(param, fixed_param_name, state_dict_folder, state_dict_index)
|
||||
elif (
|
||||
not is_quantized
|
||||
or (not hf_quantizer.requires_parameters_quantization)
|
||||
@@ -906,23 +940,21 @@ def _load_state_dict_into_meta_model(
|
||||
):
|
||||
if is_fsdp_enabled():
|
||||
param_device = "cpu" if is_local_dist_rank_0() else "meta"
|
||||
module, param_type = find_submodule_and_param_name(model, fixed_param_name)
|
||||
if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype:
|
||||
param = param[:].to(param_casting_dtype)
|
||||
module, param_type = get_module_from_name(model, fixed_param_name)
|
||||
module.load_state_dict(
|
||||
{param_type: param[:].to(param_device)},
|
||||
{param_type: param.to(param_device)},
|
||||
strict=False,
|
||||
assign=True,
|
||||
)
|
||||
else:
|
||||
hf_quantizer.create_quantized_param(
|
||||
model, param[:], fixed_param_name, param_device, state_dict, unexpected_keys
|
||||
model, param, fixed_param_name, param_device, state_dict, unexpected_keys
|
||||
)
|
||||
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
|
||||
# and then cast it to CPU to avoid excessive memory usage on each GPU
|
||||
# in comparison to the sharded model across GPUs.
|
||||
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
|
||||
module, param_type = find_submodule_and_param_name(model, fixed_param_name)
|
||||
module, param_type = get_module_from_name(model, fixed_param_name)
|
||||
value = getattr(module, param_type)
|
||||
param_to = "cpu"
|
||||
if is_fsdp_enabled() and not is_local_dist_rank_0():
|
||||
@@ -4203,7 +4235,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
elif not is_sharded:
|
||||
torch_dtype = get_state_dict_dtype(state_dict)
|
||||
else:
|
||||
one_state_dict = load_state_dict(resolved_archive_file[0], weights_only=weights_only)
|
||||
one_state_dict = load_state_dict(
|
||||
resolved_archive_file[0], map_location="meta", weights_only=weights_only
|
||||
)
|
||||
torch_dtype = get_state_dict_dtype(one_state_dict)
|
||||
del one_state_dict # free CPU memory
|
||||
logger.info(
|
||||
@@ -4848,7 +4882,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
else:
|
||||
folder = None
|
||||
|
||||
model.expected_keys = expected_keys
|
||||
model_to_load.expected_keys = expected_keys
|
||||
if device_map is not None:
|
||||
expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
|
||||
if hf_quantizer is None:
|
||||
@@ -4907,6 +4941,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
unexpected_keys=unexpected_keys,
|
||||
device_mesh=device_mesh,
|
||||
resolved_archive_file=resolved_archive_file,
|
||||
weights_only=weights_only,
|
||||
)
|
||||
else:
|
||||
# We need to read the state dict as it is meta otherwise
|
||||
@@ -4957,16 +4992,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
|
||||
if shard_file in disk_only_shard_files:
|
||||
continue
|
||||
map_location = None
|
||||
if (
|
||||
device_map is not None
|
||||
and hf_quantizer is not None
|
||||
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
|
||||
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
|
||||
):
|
||||
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
|
||||
state_dict = load_state_dict(
|
||||
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
|
||||
shard_file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only
|
||||
)
|
||||
|
||||
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
|
||||
@@ -5006,6 +5033,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
unexpected_keys=unexpected_keys,
|
||||
device_mesh=device_mesh,
|
||||
shard_file=shard_file,
|
||||
weights_only=weights_only,
|
||||
)
|
||||
error_msgs += new_error_msgs
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user