Revert parallelism temporarily (#38240)

* Revert "Protect ParallelInterface"

This reverts commit cb513e35f9.

* Revert "parallelism goes brrr (#37877)"

This reverts commit 1c2f36b480.

* Empty commit
This commit is contained in:
Lysandre Debut
2025-05-20 22:43:04 +02:00
committed by GitHub
parent feec294dea
commit 711d78d104
8 changed files with 138 additions and 1522 deletions

View File

@@ -62,9 +62,8 @@ from .integrations.flash_attention import flash_attention_forward
from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.tensor_parallel import (
ALL_PARALLEL_STYLES,
SUPPORTED_TP_STYLES,
_get_parameter_tp_plan,
initialize_tensor_parallelism,
repack_weights,
replace_state_dict_local_with_dtensor,
shard_and_distribute_module,
@@ -798,7 +797,7 @@ def _load_state_dict_into_meta_model(
param_name,
casting_dtype,
to_contiguous,
device_mesh.get_local_rank(),
int(os.environ["RANK"]), # the rank
device_mesh,
)
else:
@@ -1965,9 +1964,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
for _, v in self._tp_plan.items():
if v not in ALL_PARALLEL_STYLES:
if v not in SUPPORTED_TP_STYLES:
raise ValueError(
f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}"
f"Unsupported tensor parallel style {v}. Supported styles are {SUPPORTED_TP_STYLES}"
)
def dequantize(self):
@@ -3560,7 +3559,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
if safe_serialization:
# TODO: fix safe_serialization for tied weights
# Safetensors does not allow tensor aliasing.
# We're going to remove aliases before saving
ptrs = collections.defaultdict(list)
@@ -4042,8 +4040,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
`torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
tp_size (`str`, *optional*):
A torch tensor parallel degree. If not provided would default to world size.
device_mesh (`torch.distributed.DeviceMesh`, *optional*):
A torch device mesh. If not provided would default to world size. Used only for tensor parallel for now.
offload_folder (`str` or `os.PathLike`, *optional*):
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
offload_state_dict (`bool`, *optional*):
@@ -4141,7 +4137,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
gguf_file = kwargs.pop("gguf_file", None)
tp_plan = kwargs.pop("tp_plan", None)
tp_size = kwargs.pop("tp_size", None)
device_mesh = kwargs.pop("device_mesh", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
@@ -4177,13 +4172,59 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
# `device_map` pointing to the correct device
if device_mesh is None:
tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None)
else:
# TODO: make device_mesh support multiple dimensions
if device_mesh.ndim == 1:
raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
device_mesh = None
if tp_plan is not None:
if not is_torch_greater_or_equal("2.5"):
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
device_type = torch._C._get_accelerator().type
if not torch.distributed.is_initialized():
try:
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
if device_type == "cuda":
torch.distributed.init_process_group(
"nccl", rank=rank, world_size=world_size, init_method="env://"
)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
elif device_type == "cpu":
cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)
elif device_type == "xpu":
torch.distributed.init_process_group("ccl", rank=rank, world_size=world_size)
torch.xpu.set_device(int(os.environ["LOCAL_RANK"]))
elif device_type == "hpu":
torch.distributed.init_process_group("hccl", rank=rank, world_size=world_size)
torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
except Exception as e:
raise EnvironmentError(
"We tried to initialize torch.distributed for you, but it failed, make"
"sure you init torch distributed in your script to use `tp_plan='auto'`"
) from e
# Get device with index assuming equal number of devices per host
if device_type == "xpu":
index = torch.xpu.current_device()
elif device_type == "hpu":
index = torch.hpu.current_device()
else:
index = None if device_type == "cpu" else torch.cuda.current_device()
tp_device = torch.device(device_type, index)
if index is not None and index > 0:
import sys
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")
# This is the easiest way to dispatch to the current process device
device_map = tp_device
# Assuming sharding the model onto the world when tp_size not provided
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
if use_auth_token is not None:
warnings.warn(
@@ -5101,7 +5142,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
name,
casting_dtype,
to_contiguous,
device_mesh.get_local_rank(),
os.environ["RANK"],
device_mesh,
)