Fix DTensor import compatibility for PyTorch < 2.5 (#38836)
This commit is contained in:
@@ -172,7 +172,8 @@ _is_quantized = False
|
||||
_is_ds_init_called = False
|
||||
_torch_distributed_available = torch.distributed.is_available()
|
||||
|
||||
if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
|
||||
_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
|
||||
if _is_dtensor_available:
|
||||
from torch.distributed.tensor import DTensor
|
||||
|
||||
|
||||
@@ -3780,7 +3781,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
for shard_file, tensors in filename_to_tensors:
|
||||
shard = {}
|
||||
for tensor in tensors:
|
||||
if isinstance(state_dict[tensor], DTensor):
|
||||
if _is_dtensor_available and isinstance(state_dict[tensor], DTensor):
|
||||
full_tensor = state_dict[tensor].full_tensor()
|
||||
# to get the correctly ordered tensor we need to repack if packed
|
||||
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
|
||||
|
||||
Reference in New Issue
Block a user