Handling an exception related to HQQ quantization in modeling (#36702)
* adding exception * style * add types
This commit is contained in:
@@ -709,9 +709,19 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
|
||||
|
||||
|
||||
def _infer_parameter_dtype(
|
||||
model: "PreTrainedModel", param_name: str, empty_param, keep_in_fp32_modules=None
|
||||
model: "PreTrainedModel",
|
||||
param_name: str,
|
||||
empty_param: torch.Tensor,
|
||||
keep_in_fp32_modules: Optional[List[str]] = None,
|
||||
hf_quantizer: Optional[HfQuantizer] = None,
|
||||
) -> Union[bool, Optional[torch.dtype]]:
|
||||
old_param = model.get_parameter_or_buffer(param_name)
|
||||
try:
|
||||
old_param = model.get_parameter_or_buffer(param_name)
|
||||
except Exception as e:
|
||||
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ:
|
||||
return True, None
|
||||
else:
|
||||
raise e
|
||||
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
|
||||
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
|
||||
# in int/uint/bool and not cast them.
|
||||
@@ -781,6 +791,7 @@ def _load_state_dict_into_meta_model(
|
||||
param_name,
|
||||
empty_param,
|
||||
keep_in_fp32_modules,
|
||||
hf_quantizer,
|
||||
)
|
||||
|
||||
if device_mesh is not None: # In this case, the param is already on the correct device!
|
||||
|
||||
Reference in New Issue
Block a user