Fix import for torch 2.0, 2.1 - guard typehint for "device_mesh" (#36768)
* Fix device_mesh * Remove rebase leftover
This commit is contained in:
committed by
GitHub
parent
388e6659bf
commit
cf8091c017
@@ -755,7 +755,7 @@ def _load_state_dict_into_meta_model(
|
|||||||
is_safetensors: bool = False,
|
is_safetensors: bool = False,
|
||||||
keep_in_fp32_modules: Optional[List[str]] = None,
|
keep_in_fp32_modules: Optional[List[str]] = None,
|
||||||
unexpected_keys: Optional[List[str]] = None, # passing `unexpected` for cleanup from quantization items
|
unexpected_keys: Optional[List[str]] = None, # passing `unexpected` for cleanup from quantization items
|
||||||
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
|
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
|
||||||
) -> Tuple[Optional[Dict], Optional[Dict]]:
|
) -> Tuple[Optional[Dict], Optional[Dict]]:
|
||||||
"""Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
|
"""Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
|
||||||
device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded
|
device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded
|
||||||
@@ -4665,7 +4665,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
hf_quantizer: Optional[HfQuantizer] = None,
|
hf_quantizer: Optional[HfQuantizer] = None,
|
||||||
keep_in_fp32_modules: Optional[List[str]] = None,
|
keep_in_fp32_modules: Optional[List[str]] = None,
|
||||||
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
|
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
|
||||||
key_mapping: Optional[Dict[str, str]] = None,
|
key_mapping: Optional[Dict[str, str]] = None,
|
||||||
weights_only: bool = True,
|
weights_only: bool = True,
|
||||||
_fast_init: bool = True,
|
_fast_init: bool = True,
|
||||||
|
|||||||
Reference in New Issue
Block a user