PATCH: add back n-dim device-mesh + fix tp trainer saving (#39693)
* Feat: something
* Feat: initial changes
* tmp changes to unblock
* Refactor
* remove todo
* Feat: docstring
* Fix: saving of distributed model in trainer
* Fix: distributed saving with trainer
* Feat: add pure tp saving
* Only require tp dim if ndim > 1
* Fix: default to None
* Fix: better comments/errors
* Fix: properly check tp_size attribute
* Fix: properly check for None in tp_size

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
@@ -4472,7 +4472,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                 A torch tensor parallel degree. If not provided would default to world size.
             device_mesh (`torch.distributed.DeviceMesh`, *optional*):
                 A torch device mesh. If not provided would default to world size. Used only for tensor parallel for now.
-                If provided, it has to contain dimension named `"tp"` which will be used for tensor parallelism
+                If provided, it has to contain dimension named `"tp"` in case it's > 1 dimensional, this dimension will be used for tensor parallelism
             offload_folder (`str` or `os.PathLike`, *optional*):
                 If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
             offload_state_dict (`bool`, *optional*):
@@ -4617,10 +4617,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             if device_mesh is None:
                 tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(tp_plan, tp_size=tp_size)
             else:
-                # TODO: make device_mesh support multiple dimensions
                 if device_mesh.ndim > 1:
-                    raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
-                device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
+                    if "tp" not in device_mesh.mesh_dim_names:
+                        raise ValueError(
+                            "When using `tp_plan` and n-d `device_mesh`, it must contain a 'tp' dimension. "
+                            "Please provide a valid `device_mesh`."
+                        )
+                    device_mesh = device_mesh["tp"]
+                tp_size = device_mesh.size()
+                device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")
 
             if tp_size is None:
                 tp_size = torch.distributed.get_world_size()
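
Note on the hunk above: the new `else` branch can be read as "slice out the `"tp"` sub-mesh when the mesh is n-dimensional, then derive `tp_size` and the local device from the resulting mesh". The following is a minimal sketch of that control flow only, not the library's implementation. `StubMesh` and `resolve_tp_mesh` are hypothetical names introduced for illustration; `StubMesh` imitates just the `DeviceMesh` attributes the hunk touches (`ndim`, `mesh_dim_names`, `size()`, indexing by dimension name), so the snippet runs without `torch` or a process group.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class StubMesh:
    """Hypothetical stand-in for torch.distributed.DeviceMesh (illustration only)."""

    device_type: str
    shape: Tuple[int, ...]
    mesh_dim_names: Optional[Tuple[str, ...]] = None

    @property
    def ndim(self) -> int:
        return len(self.shape)

    def size(self) -> int:
        # Total number of ranks covered by this (sub-)mesh.
        total = 1
        for dim in self.shape:
            total *= dim
        return total

    def __getitem__(self, name: str) -> "StubMesh":
        # Slice out the 1-D sub-mesh for a named dimension.
        index = self.mesh_dim_names.index(name)
        return StubMesh(self.device_type, (self.shape[index],), (name,))


def resolve_tp_mesh(device_mesh: StubMesh, local_rank: int):
    """Mirror the branch taken when the caller supplies `device_mesh`."""
    if device_mesh.ndim > 1:
        # Assumption: this sketch treats unnamed dimensions as "no 'tp' dim";
        # the diff itself checks `mesh_dim_names` directly.
        if "tp" not in (device_mesh.mesh_dim_names or ()):
            raise ValueError(
                "When using `tp_plan` and n-d `device_mesh`, it must contain a 'tp' dimension. "
                "Please provide a valid `device_mesh`."
            )
        device_mesh = device_mesh["tp"]
    tp_size = device_mesh.size()
    # The diff wraps this string in torch.device(...); kept as a plain string here.
    device = f"{device_mesh.device_type}:{local_rank}"
    return device_mesh, tp_size, device


if __name__ == "__main__":
    mesh = StubMesh("cuda", (2, 4), ("dp", "tp"))
    print(resolve_tp_mesh(mesh, local_rank=1))
```

With real PyTorch, a comparable 2-D mesh would presumably come from `torch.distributed.device_mesh.init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))` under a launcher that sets `LOCAL_RANK`. A 1-D mesh skips the name check entirely, which matches the "Only require tp dim if ndim > 1" item in the commit message.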