From 4a60bae8e236140a4e711f773c975707b8acd032 Mon Sep 17 00:00:00 2001 From: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> Date: Thu, 13 Mar 2025 17:53:36 +0100 Subject: [PATCH] Handling an exception related to HQQ quantization in modeling (#36702) * adding exception * style * add types --- src/transformers/modeling_utils.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index fc94ec093d..723dc53b34 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -709,9 +709,19 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor] def _infer_parameter_dtype( - model: "PreTrainedModel", param_name: str, empty_param, keep_in_fp32_modules=None + model: "PreTrainedModel", + param_name: str, + empty_param: torch.Tensor, + keep_in_fp32_modules: Optional[List[str]] = None, + hf_quantizer: Optional[HfQuantizer] = None, ) -> Union[bool, Optional[torch.dtype]]: - old_param = model.get_parameter_or_buffer(param_name) + try: + old_param = model.get_parameter_or_buffer(param_name) + except Exception as e: + if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ: + return True, None + else: + raise e is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params # in int/uint/bool and not cast them. @@ -781,6 +791,7 @@ def _load_state_dict_into_meta_model( param_name, empty_param, keep_in_fp32_modules, + hf_quantizer, ) if device_mesh is not None: # In this case, the param is already on the correct device!