fix torch_dtype, contiguous, and load_state_dict regression (#36512)

* fix regression

* fix param

* fix load_state_dict

* style

* better fix for module

* fix tests

* quick fix for now

* rm print
This commit is contained in:
Marc Sun
2025-03-03 18:35:37 +01:00
committed by GitHub
parent 3e83ee75ec
commit 0463901c92

View File

@@ -67,6 +67,7 @@ from .pytorch_utils import ( # noqa: F401
translate_to_torch_parallel_style, translate_to_torch_parallel_style,
) )
from .quantizers import AutoHfQuantizer, HfQuantizer from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
from .safetensors_conversion import auto_conversion from .safetensors_conversion import auto_conversion
from .utils import ( from .utils import (
ACCELERATE_MIN_VERSION, ACCELERATE_MIN_VERSION,
@@ -536,11 +537,11 @@ str_to_torch_dtype = {
def load_state_dict( def load_state_dict(
checkpoint_file: Union[str, os.PathLike], checkpoint_file: Union[str, os.PathLike],
is_quantized: bool = False, is_quantized: bool = False,
map_location: Optional[Union[str, torch.device]] = "meta", map_location: Optional[Union[str, torch.device]] = "cpu",
weights_only: bool = True, weights_only: bool = True,
): ):
""" """
Reads a `safetensor` or a `.bin` checkpoint file into `meta` if requested. Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
""" """
if checkpoint_file.endswith(".safetensors") and is_safetensors_available(): if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
with safe_open(checkpoint_file, framework="pt") as f: with safe_open(checkpoint_file, framework="pt") as f:
@@ -771,6 +772,7 @@ def _load_state_dict_into_meta_model(
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
device_mesh=None, device_mesh=None,
shard_file=None, shard_file=None,
weights_only=True,
): ):
""" """
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -800,7 +802,15 @@ def _load_state_dict_into_meta_model(
if shard_file.endswith(".safetensors"): if shard_file.endswith(".safetensors"):
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
else: else:
bin_state_dict = load_state_dict(shard_file, map_location="cpu") map_location = "cpu"
if (
device_map is not None
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
bin_state_dict = load_state_dict(shard_file, map_location=map_location, weights_only=weights_only)
error_msgs = [] error_msgs = []
@@ -822,23 +832,36 @@ def _load_state_dict_into_meta_model(
if shard_file.endswith(".safetensors") if shard_file.endswith(".safetensors")
else bin_state_dict[serialized_param_name] else bin_state_dict[serialized_param_name]
) )
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
old_param = model
splits = fixed_param_name.split(".")
for split in splits:
# We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys.
old_param = getattr(old_param, split, None)
if old_param is None:
break
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them. # in int/uint/bool and not cast them.
param_casting_dtype = None param_casting_dtype = None
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
if dtype is not None and empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(fixed_param_name):
if (
keep_in_fp32_modules is not None
and keep_in_fp32_modules.search(fixed_param_name)
and dtype == torch.float16
):
param_casting_dtype = torch.float32 param_casting_dtype = torch.float32
else: elif dtype is not None:
param_casting_dtype = dtype param_casting_dtype = dtype
elif old_param is not None:
param_casting_dtype = old_param.dtype
if device_mesh is not None: # In this case, the param is already on the correct device! if device_mesh is not None: # In this case, the param is already on the correct device!
module_to_tp, param_type = find_submodule_and_param_name(model, fixed_param_name) module_to_tp, param_type = get_module_from_name(model, fixed_param_name)
current_module_plan = None current_module_plan = None
full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+") full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+")
if plan := re.search(full_tp_plan_, fixed_param_name): if plan := re.search(full_tp_plan_, fixed_param_name):
@@ -859,8 +882,10 @@ def _load_state_dict_into_meta_model(
else: else:
param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :] param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
shard = Shard(0) shard = Shard(0)
if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype: if param_casting_dtype is not None:
param = param.to(param_casting_dtype) param = param.to(param_casting_dtype)
if old_param.is_contiguous():
param = param.contiguous()
local_parameter = DTensor.from_local( local_parameter = DTensor.from_local(
param, param,
device_mesh=device_mesh, device_mesh=device_mesh,
@@ -873,9 +898,18 @@ def _load_state_dict_into_meta_model(
output_fn = partial(tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output) output_fn = partial(tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output)
distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn) distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn)
else: else:
module_to_tp.load_state_dict({param_type: param[:]}, strict=False, assign=True) param = param[:]
if old_param is not None and old_param.is_contiguous():
param = param.contiguous()
module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
else: else:
param = param[:]
if param_casting_dtype is not None:
param = param.to(param_casting_dtype)
if old_param is not None and old_param.is_contiguous():
param = param.contiguous()
if device_map is None: if device_map is None:
param_device = "cpu" param_device = "cpu"
else: else:
@@ -887,9 +921,9 @@ def _load_state_dict_into_meta_model(
if param_device == "disk": if param_device == "disk":
if not is_safetensors: if not is_safetensors:
offload_index = offload_weight(param[:], fixed_param_name, offload_folder, offload_index) offload_index = offload_weight(param, fixed_param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None: elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param[:], fixed_param_name, state_dict_folder, state_dict_index) state_dict_index = offload_weight(param, fixed_param_name, state_dict_folder, state_dict_index)
elif ( elif (
not is_quantized not is_quantized
or (not hf_quantizer.requires_parameters_quantization) or (not hf_quantizer.requires_parameters_quantization)
@@ -906,23 +940,21 @@ def _load_state_dict_into_meta_model(
): ):
if is_fsdp_enabled(): if is_fsdp_enabled():
param_device = "cpu" if is_local_dist_rank_0() else "meta" param_device = "cpu" if is_local_dist_rank_0() else "meta"
module, param_type = find_submodule_and_param_name(model, fixed_param_name) module, param_type = get_module_from_name(model, fixed_param_name)
if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype:
param = param[:].to(param_casting_dtype)
module.load_state_dict( module.load_state_dict(
{param_type: param[:].to(param_device)}, {param_type: param.to(param_device)},
strict=False, strict=False,
assign=True, assign=True,
) )
else: else:
hf_quantizer.create_quantized_param( hf_quantizer.create_quantized_param(
model, param[:], fixed_param_name, param_device, state_dict, unexpected_keys model, param, fixed_param_name, param_device, state_dict, unexpected_keys
) )
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
# and then cast it to CPU to avoid excessive memory usage on each GPU # and then cast it to CPU to avoid excessive memory usage on each GPU
# in comparison to the sharded model across GPUs. # in comparison to the sharded model across GPUs.
if is_fsdp_enabled() or is_deepspeed_zero3_enabled(): if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
module, param_type = find_submodule_and_param_name(model, fixed_param_name) module, param_type = get_module_from_name(model, fixed_param_name)
value = getattr(module, param_type) value = getattr(module, param_type)
param_to = "cpu" param_to = "cpu"
if is_fsdp_enabled() and not is_local_dist_rank_0(): if is_fsdp_enabled() and not is_local_dist_rank_0():
@@ -4203,7 +4235,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif not is_sharded: elif not is_sharded:
torch_dtype = get_state_dict_dtype(state_dict) torch_dtype = get_state_dict_dtype(state_dict)
else: else:
one_state_dict = load_state_dict(resolved_archive_file[0], weights_only=weights_only) one_state_dict = load_state_dict(
resolved_archive_file[0], map_location="meta", weights_only=weights_only
)
torch_dtype = get_state_dict_dtype(one_state_dict) torch_dtype = get_state_dict_dtype(one_state_dict)
del one_state_dict # free CPU memory del one_state_dict # free CPU memory
logger.info( logger.info(
@@ -4848,7 +4882,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else: else:
folder = None folder = None
model.expected_keys = expected_keys model_to_load.expected_keys = expected_keys
if device_map is not None: if device_map is not None:
expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix) expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
if hf_quantizer is None: if hf_quantizer is None:
@@ -4907,6 +4941,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
unexpected_keys=unexpected_keys, unexpected_keys=unexpected_keys,
device_mesh=device_mesh, device_mesh=device_mesh,
resolved_archive_file=resolved_archive_file, resolved_archive_file=resolved_archive_file,
weights_only=weights_only,
) )
else: else:
# We need to read the state dict as it is meta otherwise # We need to read the state dict as it is meta otherwise
@@ -4957,16 +4992,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload. # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
if shard_file in disk_only_shard_files: if shard_file in disk_only_shard_files:
continue continue
map_location = None
if (
device_map is not None
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
state_dict = load_state_dict( state_dict = load_state_dict(
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only shard_file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only
) )
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
@@ -5006,6 +5033,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
unexpected_keys=unexpected_keys, unexpected_keys=unexpected_keys,
device_mesh=device_mesh, device_mesh=device_mesh,
shard_file=shard_file, shard_file=shard_file,
weights_only=weights_only,
) )
error_msgs += new_error_msgs error_msgs += new_error_msgs
else: else: