@@ -732,6 +732,8 @@ def _infer_parameter_dtype(
|
|||||||
if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(param_name):
|
if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(param_name):
|
||||||
casting_dtype = torch.float32
|
casting_dtype = torch.float32
|
||||||
# Then dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes
|
# Then dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes
|
||||||
|
elif hf_quantizer is not None:
|
||||||
|
casting_dtype = model.config._pre_quantization_dtype
|
||||||
else:
|
else:
|
||||||
casting_dtype = old_param.dtype
|
casting_dtype = old_param.dtype
|
||||||
return old_param is not None and old_param.is_contiguous(), casting_dtype
|
return old_param is not None and old_param.is_contiguous(), casting_dtype
|
||||||
@@ -754,7 +756,6 @@ def _load_state_dict_into_meta_model(
|
|||||||
keep_in_fp32_modules: Optional[List[str]] = None,
|
keep_in_fp32_modules: Optional[List[str]] = None,
|
||||||
unexpected_keys: Optional[List[str]] = None, # passing `unexpected` for cleanup from quantization items
|
unexpected_keys: Optional[List[str]] = None, # passing `unexpected` for cleanup from quantization items
|
||||||
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
|
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
|
||||||
weights_only=True,
|
|
||||||
) -> Tuple[Optional[Dict], Optional[Dict]]:
|
) -> Tuple[Optional[Dict], Optional[Dict]]:
|
||||||
"""Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
|
"""Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
|
||||||
device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded
|
device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded
|
||||||
@@ -4883,7 +4884,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||||
unexpected_keys=unexpected_keys,
|
unexpected_keys=unexpected_keys,
|
||||||
device_mesh=device_mesh,
|
device_mesh=device_mesh,
|
||||||
weights_only=weights_only,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assign_params = check_support_param_buffer_assignment(model_to_load, state_dict)
|
assign_params = check_support_param_buffer_assignment(model_to_load, state_dict)
|
||||||
|
|||||||
Reference in New Issue
Block a user