Handle an exception related to HQQ quantization in modeling (#36702)

* adding exception

* style

* add types
This commit is contained in:
Mohamed Mekkouri
2025-03-13 17:53:36 +01:00
committed by GitHub
parent 09a309d273
commit 4a60bae8e2

View File

@@ -709,9 +709,19 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
def _infer_parameter_dtype( def _infer_parameter_dtype(
model: "PreTrainedModel", param_name: str, empty_param, keep_in_fp32_modules=None model: "PreTrainedModel",
param_name: str,
empty_param: torch.Tensor,
keep_in_fp32_modules: Optional[List[str]] = None,
hf_quantizer: Optional[HfQuantizer] = None,
) -> Union[bool, Optional[torch.dtype]]: ) -> Union[bool, Optional[torch.dtype]]:
try:
old_param = model.get_parameter_or_buffer(param_name) old_param = model.get_parameter_or_buffer(param_name)
except Exception as e:
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ:
return True, None
else:
raise e
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them. # in int/uint/bool and not cast them.
@@ -781,6 +791,7 @@ def _load_state_dict_into_meta_model(
param_name, param_name,
empty_param, empty_param,
keep_in_fp32_modules, keep_in_fp32_modules,
hf_quantizer,
) )
if device_mesh is not None: # In this case, the param is already on the correct device! if device_mesh is not None: # In this case, the param is already on the correct device!