Allow device_mesh have multiple dim (#38949)

* Feat: something

* Feat: initial changes

* tmp changes to unblock

* Refactor

* remove todo

* Feat: docstring

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Matej Sirovatka
2025-07-23 14:27:36 +02:00
committed by GitHub
parent 10c990f7e2
commit 82603b6cc2

View File

@@ -4581,6 +4581,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
A torch tensor parallel degree. If not provided would default to world size.
device_mesh (`torch.distributed.DeviceMesh`, *optional*):
A torch device mesh. If not provided would default to world size. Used only for tensor parallel for now.
If provided, it has to contain dimension named `"tp"` which will be used for tensor parallelism
offload_folder (`str` or `os.PathLike`, *optional*):
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
offload_state_dict (`bool`, *optional*):
@@ -4718,13 +4719,17 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
# We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
# `device_map` pointing to the correct device
if tp_plan is not None:
if device_mesh is None and tp_plan is not None:
if device_mesh is None:
tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None)
else:
# TODO: make device_mesh support multiple dimensions
if device_mesh.ndim != 1:
raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
if "tp" not in device_mesh.mesh_dim_names:
raise ValueError(
"When using `tp_plan`, the `device_mesh` must contain a 'tp' dimension. "
"Please provide a valid `device_mesh`."
)
device_mesh = device_mesh["tp"]
tp_size = device_mesh["tp"].size()
device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")
if tp_size is None:
tp_size = torch.distributed.get_world_size()