Fix DTensor import compatibility for PyTorch < 2.5 (#38836)

This commit is contained in:
Benoqtr
2025-06-23 17:25:56 +08:00
committed by GitHub
parent 984ff89e73
commit c184550daf

View File

@@ -172,7 +172,8 @@ _is_quantized = False
 _is_ds_init_called = False
 _torch_distributed_available = torch.distributed.is_available()
-if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
+_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
+if _is_dtensor_available:
     from torch.distributed.tensor import DTensor
@@ -3780,7 +3781,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         for shard_file, tensors in filename_to_tensors:
             shard = {}
             for tensor in tensors:
-                if isinstance(state_dict[tensor], DTensor):
+                if _is_dtensor_available and isinstance(state_dict[tensor], DTensor):
                     full_tensor = state_dict[tensor].full_tensor()
                     # to get the correctly ordered tensor we need to repack if packed
                     if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):