From 4c7da9fedf185cecda3de3945fa4c84e5c7ca996 Mon Sep 17 00:00:00 2001 From: Matej Sirovatka <54212263+S1ro1@users.noreply.github.com> Date: Mon, 28 Jul 2025 14:29:58 +0200 Subject: [PATCH] PATCH: add back n-dim device-mesh + fix tp trainer saving (#39693) * Feat: something * Feat: initial changes * tmp changes to unblock * Refactor * remove todo * Feat: docstring * Fix: saving of distributed model in trainer * Fix: distributed saving with trainer * Feat: add pure tp saving * Only require tp dim if ndim > 1 * Fix: default to None * Fix: better comments/errors * Fix: properly check tp_size attribute * Fix: properly check for None in tp_size --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/modeling_utils.py | 13 +++++++++---- src/transformers/trainer.py | 7 +++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d68f5f6dfe..a97b446c96 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4472,7 +4472,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH A torch tensor parallel degree. If not provided would default to world size. device_mesh (`torch.distributed.DeviceMesh`, *optional*): A torch device mesh. If not provided would default to world size. Used only for tensor parallel for now. - If provided, it has to contain dimension named `"tp"` which will be used for tensor parallelism + If provided and it has more than one dimension, it must contain a dimension named `"tp"`, which will be used for tensor parallelism. offload_folder (`str` or `os.PathLike`, *optional*): If the `device_map` contains any value `"disk"`, the folder where we will offload weights. 
offload_state_dict (`bool`, *optional*): @@ -4617,10 +4617,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH if device_mesh is None: tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(tp_plan, tp_size=tp_size) else: - # TODO: make device_mesh support multiple dimensions if device_mesh.ndim > 1: - raise ValueError("device_mesh must be 1 dimensional and will be used for TP") - device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"])) + if "tp" not in device_mesh.mesh_dim_names: + raise ValueError( + "When using `tp_plan` and n-d `device_mesh`, it must contain a 'tp' dimension. " + "Please provide a valid `device_mesh`." + ) + device_mesh = device_mesh["tp"] + tp_size = device_mesh.size() + device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}") if tp_size is None: tp_size = torch.distributed.get_world_size() diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 43cb4b88f2..52dc9c3557 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3953,6 +3953,13 @@ class Trainer: if IS_SAGEMAKER_MP_POST_1_10: # 'user_content.pt' indicates model state_dict saved with smp >= 1.10 Path(os.path.join(output_dir, "user_content.pt")).touch() + # We are in N-D parallelism if `parallelism_config` is set, so we ask accelerate whether this rank should save the model + elif getattr(self.accelerator, "parallelism_config", None) is not None: + if self.accelerator.should_save_model: + self._save(output_dir) + # If we get here, we're in 1D parallelism, so all ranks need to go to `save_pretrained` + elif (tp_size := getattr(self.model, "_tp_size", 0)) is not None and tp_size > 1: + self._save(output_dir) elif self.is_fsdp_enabled: if ("FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)) and ( version.parse(accelerate_version) > version.parse("0.24.1")