Compare commits

...

3 Commits

Author SHA1 Message Date
Arthur Zucker
f4fc42216c v 4.52.3
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
2025-05-22 16:29:44 +02:00
Marc Sun
48459c97d7 Fix tp error when torch distributed is already initialized (#38294)
fix tp error
2025-05-22 16:29:24 +02:00
Arthur
597e159145 Protect ParallelInterface (#38262)
Co-authored-by: Lysandre <hi@lysand.re>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
2025-05-22 16:29:16 +02:00
3 changed files with 25 additions and 16 deletions

View File

@@ -451,7 +451,7 @@ install_requires = [
setup(
name="transformers",
version="4.52.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.52.3", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.52.2"
__version__ = "4.52.3"
from pathlib import Path
from typing import TYPE_CHECKING

View File

@@ -52,6 +52,7 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None):
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
device_type = torch._C._get_accelerator().type
current_device = getattr(torch, device_type)
if not torch.distributed.is_initialized():
try:
rank = int(os.environ["RANK"])
@@ -73,6 +74,9 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None):
"We tried to initialize torch.distributed for you, but it failed. Make "
"sure you init torch distributed in your script to use `tp_plan='auto'`."
) from e
if device_type != "cpu":
current_device.set_device(int(os.environ["LOCAL_RANK"]))
index = current_device.current_device() if device_type != "cpu" else None
tp_device = torch.device(device_type, index)
@@ -729,23 +733,24 @@ class ParallelInterface(MutableMapping):
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
# a new instance is created (in order to locally override a given function)
_global_mapping = {
"colwise": ColwiseParallel(),
"rowwise": RowwiseParallel(),
"colwise_rep": ColwiseParallel(output_layouts=Replicate()),
"rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
"local_colwise": ColwiseParallel(use_dtensor=False),
"local_rowwise": RowwiseParallel(use_dtensor=False),
"local": IsolatedParallel(),
"gather": GatherParallel(),
"local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
"sequence_parallel": SequenceParallel(),
"replicate": ReplicateParallel(),
}
def __init__(self):
self._local_mapping = {}
ParallelInterface._global_mapping = {
"colwise": ColwiseParallel(),
"rowwise": RowwiseParallel(),
"colwise_rep": ColwiseParallel(output_layouts=Replicate()),
"rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
"local_colwise": ColwiseParallel(use_dtensor=False),
"local_rowwise": RowwiseParallel(use_dtensor=False),
"local": IsolatedParallel(),
"gather": GatherParallel(),
"local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
"sequence_parallel": SequenceParallel(),
"replicate": ReplicateParallel(),
}
def __getitem__(self, key):
# First check if instance has a local override
if key in self._local_mapping:
@@ -775,7 +780,11 @@ class ParallelInterface(MutableMapping):
# Global ParallelInterface shared by all models which do not need to overwrite any of the existing ones
ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
else:
ALL_PARALLEL_STYLES = None
def convert_local_tensor_to_dtensor(