fix ndim check of device_mesh for TP (#39538)

This commit is contained in:
Wing Lian
2025-07-21 09:09:33 -04:00
committed by GitHub
parent 1aa7256f01
commit 4b4f04fcca

View File

@@ -4522,7 +4522,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None) tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None)
else: else:
# TODO: make device_mesh support multiple dimensions # TODO: make device_mesh support multiple dimensions
if device_mesh.ndim == 1: if device_mesh.ndim != 1:
raise ValueError("device_mesh must be 1 dimensional and will be used for TP") raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"])) device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))