fix ndim check of device_mesh for TP (#39538)
This commit is contained in:
@@ -4522,7 +4522,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None)
|
tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None)
|
||||||
else:
|
else:
|
||||||
# TODO: make device_mesh support multiple dimensions
|
# TODO: make device_mesh support multiple dimensions
|
||||||
if device_mesh.ndim == 1:
|
if device_mesh.ndim != 1:
|
||||||
raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
|
raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
|
||||||
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
|
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user