Fix casting dtype for quantization (#36799)

* fix

* remove print
This commit is contained in:
Marc Sun
2025-03-18 18:46:03 +01:00
committed by GitHub
parent 30580f035b
commit 14b597f518

View File

@@ -732,6 +732,8 @@ def _infer_parameter_dtype(
if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(param_name): if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(param_name):
casting_dtype = torch.float32 casting_dtype = torch.float32
# Then dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes # Then dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes
elif hf_quantizer is not None:
casting_dtype = model.config._pre_quantization_dtype
else: else:
casting_dtype = old_param.dtype casting_dtype = old_param.dtype
return old_param is not None and old_param.is_contiguous(), casting_dtype return old_param is not None and old_param.is_contiguous(), casting_dtype
@@ -754,7 +756,6 @@ def _load_state_dict_into_meta_model(
keep_in_fp32_modules: Optional[List[str]] = None, keep_in_fp32_modules: Optional[List[str]] = None,
unexpected_keys: Optional[List[str]] = None, # passing `unexpected` for cleanup from quantization items unexpected_keys: Optional[List[str]] = None, # passing `unexpected` for cleanup from quantization items
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
weights_only=True,
) -> Tuple[Optional[Dict], Optional[Dict]]: ) -> Tuple[Optional[Dict], Optional[Dict]]:
"""Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta """Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded
@@ -4883,7 +4884,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
keep_in_fp32_modules=keep_in_fp32_modules, keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys, unexpected_keys=unexpected_keys,
device_mesh=device_mesh, device_mesh=device_mesh,
weights_only=weights_only,
) )
else: else:
assign_params = check_support_param_buffer_assignment(model_to_load, state_dict) assign_params = check_support_param_buffer_assignment(model_to_load, state_dict)