Fix import for torch 2.0, 2.1 - guard typehint for "device_mesh" (#36768)

* Fix device_mesh

* Remove rebase leftover
This commit is contained in:
Pavel Iakubovskii
2025-03-20 11:55:47 +00:00
committed by GitHub
parent 388e6659bf
commit cf8091c017

View File

@@ -755,7 +755,7 @@ def _load_state_dict_into_meta_model(
is_safetensors: bool = False,
keep_in_fp32_modules: Optional[List[str]] = None,
unexpected_keys: Optional[List[str]] = None, # passing `unexpected` for cleanup from quantization items
-device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
) -> Tuple[Optional[Dict], Optional[Dict]]:
"""Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded
@@ -4665,7 +4665,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
dtype: Optional[torch.dtype] = None,
hf_quantizer: Optional[HfQuantizer] = None,
keep_in_fp32_modules: Optional[List[str]] = None,
-device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
key_mapping: Optional[Dict[str, str]] = None,
weights_only: bool = True,
_fast_init: bool = True,