From 82603b6cc284dbdf2b7a7cf070feb6a2c3bb53cf Mon Sep 17 00:00:00 2001 From: Matej Sirovatka <54212263+S1ro1@users.noreply.github.com> Date: Wed, 23 Jul 2025 14:27:36 +0200 Subject: [PATCH] Allow `device_mesh` have multiple dim (#38949) * Feat: something * Feat: initial changes * tmp changes to unblock * Refactor * remove todo * Feat: docstring --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/modeling_utils.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index dc4e997fef..1577e7db58 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4581,6 +4581,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH A torch tensor parallel degree. If not provided would default to world size. device_mesh (`torch.distributed.DeviceMesh`, *optional*): A torch device mesh. If not provided would default to world size. Used only for tensor parallel for now. + If provided, it has to contain dimension named `"tp"` which will be used for tensor parallelism offload_folder (`str` or `os.PathLike`, *optional*): If the `device_map` contains any value `"disk"`, the folder where we will offload weights. offload_state_dict (`bool`, *optional*): @@ -4718,13 +4719,17 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH # We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple # `device_map` pointing to the correct device if tp_plan is not None: - if device_mesh is None and tp_plan is not None: + if device_mesh is None: tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None) else: - # TODO: make device_mesh support multiple dimensions - if device_mesh.ndim != 1: - raise ValueError("device_mesh must be 1 dimensional and will be used for TP") - device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"])) + if "tp" not in device_mesh.mesh_dim_names: + raise ValueError( + "When using `tp_plan`, the `device_mesh` must contain a 'tp' dimension. " + "Please provide a valid `device_mesh`." + ) + device_mesh = device_mesh["tp"] + tp_size = device_mesh["tp"].size() + device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}") if tp_size is None: tp_size = torch.distributed.get_world_size()