Fix import for torch 2.0 and 2.1: guard the type hint for "device_mesh" (#36768)

* Fix device_mesh: quote the type hint so it is not evaluated at import time

* Remove rebase leftover
This commit is contained in:
Pavel Iakubovskii
2025-03-20 11:55:47 +00:00
committed by GitHub
parent 388e6659bf
commit cf8091c017

View File

@@ -755,7 +755,7 @@ def _load_state_dict_into_meta_model(
is_safetensors: bool = False, is_safetensors: bool = False,
keep_in_fp32_modules: Optional[List[str]] = None, keep_in_fp32_modules: Optional[List[str]] = None,
unexpected_keys: Optional[List[str]] = None, # passing `unexpected` for cleanup from quantization items unexpected_keys: Optional[List[str]] = None, # passing `unexpected` for cleanup from quantization items
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
) -> Tuple[Optional[Dict], Optional[Dict]]: ) -> Tuple[Optional[Dict], Optional[Dict]]:
"""Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta """Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded
@@ -4665,7 +4665,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
hf_quantizer: Optional[HfQuantizer] = None, hf_quantizer: Optional[HfQuantizer] = None,
keep_in_fp32_modules: Optional[List[str]] = None, keep_in_fp32_modules: Optional[List[str]] = None,
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
key_mapping: Optional[Dict[str, str]] = None, key_mapping: Optional[Dict[str, str]] = None,
weights_only: bool = True, weights_only: bool = True,
_fast_init: bool = True, _fast_init: bool = True,