From c184550dafcc214fd10cddec98675a8c68a6440f Mon Sep 17 00:00:00 2001 From: Benoqtr <155428839+Benoqtr@users.noreply.github.com> Date: Mon, 23 Jun 2025 17:25:56 +0800 Subject: [PATCH] Fix DTensor import compatibility for PyTorch < 2.5 (#38836) --- src/transformers/modeling_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ce4cdab8b8..0c514ec1bb 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -172,7 +172,8 @@ _is_quantized = False _is_ds_init_called = False _torch_distributed_available = torch.distributed.is_available() -if _torch_distributed_available and is_torch_greater_or_equal("2.5"): +_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5") +if _is_dtensor_available: from torch.distributed.tensor import DTensor @@ -3780,7 +3781,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi for shard_file, tensors in filename_to_tensors: shard = {} for tensor in tensors: - if isinstance(state_dict[tensor], DTensor): + if _is_dtensor_available and isinstance(state_dict[tensor], DTensor): full_tensor = state_dict[tensor].full_tensor() # to get the correctly ordered tensor we need to repack if packed if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):