Fix import for torch 2.0, 2.1 - guard typehint for "device_mesh" (#36768)
* Fix device_mesh * Remove rebase leftover
This commit is contained in:
committed by
GitHub
parent
388e6659bf
commit
cf8091c017
@@ -755,7 +755,7 @@ def _load_state_dict_into_meta_model(
|
||||
is_safetensors: bool = False,
|
||||
keep_in_fp32_modules: Optional[List[str]] = None,
|
||||
unexpected_keys: Optional[List[str]] = None, # passing `unexpected` for cleanup from quantization items
|
||||
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
|
||||
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
|
||||
) -> Tuple[Optional[Dict], Optional[Dict]]:
|
||||
"""Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
|
||||
device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded
|
||||
@@ -4665,7 +4665,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
hf_quantizer: Optional[HfQuantizer] = None,
|
||||
keep_in_fp32_modules: Optional[List[str]] = None,
|
||||
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
|
||||
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
|
||||
key_mapping: Optional[Dict[str, str]] = None,
|
||||
weights_only: bool = True,
|
||||
_fast_init: bool = True,
|
||||
|
||||
Reference in New Issue
Block a user