From 0463901c92e08cefbccf19f409b6cc43c153352d Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Mon, 3 Mar 2025 18:35:37 +0100 Subject: [PATCH] fix torch_dtype, contiguous, and load_state_dict regression (#36512) * fix regression * fix param * fix load_state_dict * style * better fix for module * fix tests * quick fix for now * rm print --- src/transformers/modeling_utils.py | 94 +++++++++++++++++++----------- 1 file changed, 61 insertions(+), 33 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a016f6013f..9cb3de74c2 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -67,6 +67,7 @@ from .pytorch_utils import ( # noqa: F401 translate_to_torch_parallel_style, ) from .quantizers import AutoHfQuantizer, HfQuantizer +from .quantizers.quantizers_utils import get_module_from_name from .safetensors_conversion import auto_conversion from .utils import ( ACCELERATE_MIN_VERSION, @@ -536,11 +537,11 @@ str_to_torch_dtype = { def load_state_dict( checkpoint_file: Union[str, os.PathLike], is_quantized: bool = False, - map_location: Optional[Union[str, torch.device]] = "meta", + map_location: Optional[Union[str, torch.device]] = "cpu", weights_only: bool = True, ): """ - Reads a `safetensor` or a `.bin` checkpoint file into `meta` if requested. + Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default. """ if checkpoint_file.endswith(".safetensors") and is_safetensors_available(): with safe_open(checkpoint_file, framework="pt") as f: @@ -771,6 +772,7 @@ def _load_state_dict_into_meta_model( unexpected_keys=None, # passing `unexpected` for cleanup from quantization items device_mesh=None, shard_file=None, + weights_only=True, ): """ This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its @@ -800,7 +802,15 @@ def _load_state_dict_into_meta_model( if shard_file.endswith(".safetensors"): file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) else: - bin_state_dict = load_state_dict(shard_file, map_location="cpu") + map_location = "cpu" + if ( + device_map is not None + and hf_quantizer is not None + and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO + and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"] + ): + map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) + bin_state_dict = load_state_dict(shard_file, map_location=map_location, weights_only=weights_only) error_msgs = [] @@ -822,23 +832,36 @@ def _load_state_dict_into_meta_model( if shard_file.endswith(".safetensors") else bin_state_dict[serialized_param_name] ) + + # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which + # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model. + # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29 + + old_param = model + splits = fixed_param_name.split(".") + for split in splits: + # We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys. + old_param = getattr(old_param, split, None) + if old_param is None: + break + + if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)): + old_param = None + # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params # in int/uint/bool and not cast them. param_casting_dtype = None is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn - - if dtype is not None and empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: - if ( - keep_in_fp32_modules is not None - and keep_in_fp32_modules.search(fixed_param_name) - and dtype == torch.float16 - ): + if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: + if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(fixed_param_name): param_casting_dtype = torch.float32 - else: + elif dtype is not None: param_casting_dtype = dtype + elif old_param is not None: + param_casting_dtype = old_param.dtype if device_mesh is not None: # In this case, the param is already on the correct device! - module_to_tp, param_type = find_submodule_and_param_name(model, fixed_param_name) + module_to_tp, param_type = get_module_from_name(model, fixed_param_name) current_module_plan = None full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+") if plan := re.search(full_tp_plan_, fixed_param_name): @@ -859,8 +882,10 @@ def _load_state_dict_into_meta_model( else: param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :] shard = Shard(0) - if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype: + if param_casting_dtype is not None: param = param.to(param_casting_dtype) + if old_param.is_contiguous(): + param = param.contiguous() local_parameter = DTensor.from_local( param, device_mesh=device_mesh, @@ -873,9 +898,18 @@ def _load_state_dict_into_meta_model( output_fn = partial(tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output) distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn) else: - module_to_tp.load_state_dict({param_type: param[:]}, strict=False, assign=True) + param = param[:] + if old_param is not None and old_param.is_contiguous(): + param = param.contiguous() + module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True) else: + param = param[:] + if param_casting_dtype is not None: + param = param.to(param_casting_dtype) + if old_param is not None and old_param.is_contiguous(): + param = param.contiguous() + if device_map is None: param_device = "cpu" else: @@ -887,9 +921,9 @@ def _load_state_dict_into_meta_model( if param_device == "disk": if not is_safetensors: - offload_index = offload_weight(param[:], fixed_param_name, offload_folder, offload_index) + offload_index = offload_weight(param, fixed_param_name, offload_folder, offload_index) elif param_device == "cpu" and state_dict_index is not None: - state_dict_index = offload_weight(param[:], fixed_param_name, state_dict_folder, state_dict_index) + state_dict_index = offload_weight(param, fixed_param_name, state_dict_folder, state_dict_index) elif ( not is_quantized or (not hf_quantizer.requires_parameters_quantization) @@ -906,23 +940,21 @@ def _load_state_dict_into_meta_model( ): if is_fsdp_enabled(): param_device = "cpu" if is_local_dist_rank_0() else "meta" - module, param_type = find_submodule_and_param_name(model, fixed_param_name) - if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype: - param = param[:].to(param_casting_dtype) + module, param_type = get_module_from_name(model, fixed_param_name) module.load_state_dict( - {param_type: param[:].to(param_device)}, + {param_type: param.to(param_device)}, strict=False, assign=True, ) else: hf_quantizer.create_quantized_param( - model, param[:], fixed_param_name, param_device, state_dict, unexpected_keys + model, param, fixed_param_name, param_device, state_dict, unexpected_keys ) # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU # and then cast it to CPU to avoid excessive memory usage on each GPU # in comparison to the sharded model across GPUs. if is_fsdp_enabled() or is_deepspeed_zero3_enabled(): - module, param_type = find_submodule_and_param_name(model, fixed_param_name) + module, param_type = get_module_from_name(model, fixed_param_name) value = getattr(module, param_type) param_to = "cpu" if is_fsdp_enabled() and not is_local_dist_rank_0(): @@ -4203,7 +4235,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix elif not is_sharded: torch_dtype = get_state_dict_dtype(state_dict) else: - one_state_dict = load_state_dict(resolved_archive_file[0], weights_only=weights_only) + one_state_dict = load_state_dict( + resolved_archive_file[0], map_location="meta", weights_only=weights_only + ) torch_dtype = get_state_dict_dtype(one_state_dict) del one_state_dict # free CPU memory logger.info( @@ -4848,7 +4882,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix else: folder = None - model.expected_keys = expected_keys + model_to_load.expected_keys = expected_keys if device_map is not None: expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix) if hf_quantizer is None: @@ -4907,6 +4941,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix unexpected_keys=unexpected_keys, device_mesh=device_mesh, resolved_archive_file=resolved_archive_file, + weights_only=weights_only, ) else: # We need to read the state dict as it is meta otherwise @@ -4957,16 +4992,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload. if shard_file in disk_only_shard_files: continue - map_location = None - if ( - device_map is not None - and hf_quantizer is not None - and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO - and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"] - ): - map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) state_dict = load_state_dict( - shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only + shard_file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only ) # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not @@ -5006,6 +5033,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix unexpected_keys=unexpected_keys, device_mesh=device_mesh, shard_file=shard_file, + weights_only=weights_only, ) error_msgs += new_error_msgs else: