From cf8091c017533c03be73b84ab535ae9c80924796 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Thu, 20 Mar 2025 11:55:47 +0000 Subject: [PATCH] Fix import for torch 2.0, 2.1 - guard typehint for "device_mesh" (#36768) * Fix device_mesh * Remove rebase leftover --- src/transformers/modeling_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a4b13ca0f2..c4e8a960af 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -755,7 +755,7 @@ def _load_state_dict_into_meta_model( is_safetensors: bool = False, keep_in_fp32_modules: Optional[List[str]] = None, unexpected_keys: Optional[List[str]] = None, # passing `unexpected` for cleanup from quantization items - device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, + device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None, ) -> Tuple[Optional[Dict], Optional[Dict]]: """Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded @@ -4665,7 +4665,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix dtype: Optional[torch.dtype] = None, hf_quantizer: Optional[HfQuantizer] = None, keep_in_fp32_modules: Optional[List[str]] = None, - device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, + device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None, key_mapping: Optional[Dict[str, str]] = None, weights_only: bool = True, _fast_init: bool = True,