Handling an exception related to HQQ quantization in modeling (#36702)
* adding exception * style * add types
This commit is contained in:
@@ -709,9 +709,19 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
|
|||||||
|
|
||||||
|
|
||||||
def _infer_parameter_dtype(
|
def _infer_parameter_dtype(
|
||||||
model: "PreTrainedModel", param_name: str, empty_param, keep_in_fp32_modules=None
|
model: "PreTrainedModel",
|
||||||
|
param_name: str,
|
||||||
|
empty_param: torch.Tensor,
|
||||||
|
keep_in_fp32_modules: Optional[List[str]] = None,
|
||||||
|
hf_quantizer: Optional[HfQuantizer] = None,
|
||||||
) -> Union[bool, Optional[torch.dtype]]:
|
) -> Union[bool, Optional[torch.dtype]]:
|
||||||
old_param = model.get_parameter_or_buffer(param_name)
|
try:
|
||||||
|
old_param = model.get_parameter_or_buffer(param_name)
|
||||||
|
except Exception as e:
|
||||||
|
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ:
|
||||||
|
return True, None
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
|
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
|
||||||
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
|
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
|
||||||
# in int/uint/bool and not cast them.
|
# in int/uint/bool and not cast them.
|
||||||
@@ -781,6 +791,7 @@ def _load_state_dict_into_meta_model(
|
|||||||
param_name,
|
param_name,
|
||||||
empty_param,
|
empty_param,
|
||||||
keep_in_fp32_modules,
|
keep_in_fp32_modules,
|
||||||
|
hf_quantizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
if device_mesh is not None: # In this case, the param is already on the correct device!
|
if device_mesh is not None: # In this case, the param is already on the correct device!
|
||||||
|
|||||||
Reference in New Issue
Block a user