From c7f2b79dd853c28e27076686c331a1b702ecd38e Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Fri, 30 May 2025 17:36:00 +0200 Subject: [PATCH] protect dtensor import (#38496) protect --- src/transformers/modeling_utils.py | 1 - src/transformers/pytorch_utils.py | 5 +---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5e972b3423..9639ff7ce0 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -37,7 +37,6 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVa from zipfile import is_zipfile import torch -import torch.distributed.tensor from huggingface_hub import split_torch_state_dict_into_shards from packaging import version from torch import Tensor, nn diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index ca60a05a26..9bb02bff96 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -42,9 +42,6 @@ is_torch_greater_or_equal_than_1_12 = is_torch_greater_or_equal("1.12", accept_d # Cache this result has it's a C FFI call which can be pretty time-consuming _torch_distributed_available = torch.distributed.is_available() -if is_torch_greater_or_equal("2.5") and _torch_distributed_available: - pass - def softmax_backward_data(parent, grad_output, output, dim, self): """ @@ -296,7 +293,7 @@ def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]: guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with non-overlapping lifetimes may have the same id. """ - if is_torch_greater_or_equal_than_2_0: + if _torch_distributed_available and is_torch_greater_or_equal("2.5"): from torch.distributed.tensor import DTensor if isinstance(tensor, DTensor):