protect dtensor import (#38496)

protect
This commit is contained in:
Marc Sun
2025-05-30 17:36:00 +02:00
committed by GitHub
parent 051a8acc9a
commit c7f2b79dd8
2 changed files with 1 additions and 5 deletions

View File

@@ -37,7 +37,6 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVa
from zipfile import is_zipfile from zipfile import is_zipfile
import torch import torch
import torch.distributed.tensor
from huggingface_hub import split_torch_state_dict_into_shards from huggingface_hub import split_torch_state_dict_into_shards
from packaging import version from packaging import version
from torch import Tensor, nn from torch import Tensor, nn

View File

@@ -42,9 +42,6 @@ is_torch_greater_or_equal_than_1_12 = is_torch_greater_or_equal("1.12", accept_d
# Cache this result has it's a C FFI call which can be pretty time-consuming # Cache this result has it's a C FFI call which can be pretty time-consuming
_torch_distributed_available = torch.distributed.is_available() _torch_distributed_available = torch.distributed.is_available()
if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
pass
def softmax_backward_data(parent, grad_output, output, dim, self): def softmax_backward_data(parent, grad_output, output, dim, self):
""" """
@@ -296,7 +293,7 @@ def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
non-overlapping lifetimes may have the same id. non-overlapping lifetimes may have the same id.
""" """
if is_torch_greater_or_equal_than_2_0: if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
from torch.distributed.tensor import DTensor from torch.distributed.tensor import DTensor
if isinstance(tensor, DTensor): if isinstance(tensor, DTensor):