Fix DTensor import compatibility for PyTorch < 2.5 (#38836)

This commit is contained in:
Benoqtr
2025-06-23 17:25:56 +08:00
committed by GitHub
parent 984ff89e73
commit c184550daf

View File

@@ -172,7 +172,8 @@ _is_quantized = False
_is_ds_init_called = False
_torch_distributed_available = torch.distributed.is_available()
if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
if _is_dtensor_available:
from torch.distributed.tensor import DTensor
@@ -3780,7 +3781,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
for shard_file, tensors in filename_to_tensors:
shard = {}
for tensor in tensors:
if isinstance(state_dict[tensor], DTensor):
if _is_dtensor_available and isinstance(state_dict[tensor], DTensor):
full_tensor = state_dict[tensor].full_tensor()
# to get the correctly ordered tensor we need to repack if packed
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):