Fix DTensor import compatibility for PyTorch < 2.5 (#38836)
This commit is contained in:
@@ -172,7 +172,8 @@ _is_quantized = False
|
|||||||
_is_ds_init_called = False
|
_is_ds_init_called = False
|
||||||
_torch_distributed_available = torch.distributed.is_available()
|
_torch_distributed_available = torch.distributed.is_available()
|
||||||
|
|
||||||
if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
|
_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
|
||||||
|
if _is_dtensor_available:
|
||||||
from torch.distributed.tensor import DTensor
|
from torch.distributed.tensor import DTensor
|
||||||
|
|
||||||
|
|
||||||
@@ -3780,7 +3781,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
for shard_file, tensors in filename_to_tensors:
|
for shard_file, tensors in filename_to_tensors:
|
||||||
shard = {}
|
shard = {}
|
||||||
for tensor in tensors:
|
for tensor in tensors:
|
||||||
if isinstance(state_dict[tensor], DTensor):
|
if _is_dtensor_available and isinstance(state_dict[tensor], DTensor):
|
||||||
full_tensor = state_dict[tensor].full_tensor()
|
full_tensor = state_dict[tensor].full_tensor()
|
||||||
# to get the correctly ordered tensor we need to repack if packed
|
# to get the correctly ordered tensor we need to repack if packed
|
||||||
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
|
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
|
||||||
|
|||||||
Reference in New Issue
Block a user