From 4b4f04fccaaa3020c5462cf31d286d83fbfc6d38 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 21 Jul 2025 09:09:33 -0400
Subject: [PATCH] fix ndim check of device_mesh for TP (#39538)

---
 src/transformers/modeling_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 61110e926d..e20f4c0fe7 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4522,7 +4522,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                 tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None)
             else:
                 # TODO: make device_mesh support multiple dimensions
-                if device_mesh.ndim == 1:
+                if device_mesh.ndim != 1:
                     raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
                 device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
 