(Part 2) feat: allow for tp_size attr for tplizing the model (#37054)

* feat: custom tp_size, new transformers tp interface

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* fix: review cmt - error when tp_plan not set for tp_size

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* fix: nit in docs

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

---------

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Matej Sirovatka <54212263+S1ro1@users.noreply.github.com>
This commit is contained in:
Mehant Kammakomati
2025-04-10 21:14:09 +05:30
committed by GitHub
parent dac443414e
commit 7d76876498
7 changed files with 27 additions and 120 deletions

View File

@@ -1788,6 +1788,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# for example.
_tp_plan = None
# Tensor parallel degree the model is sharded to.
_tp_size = None
# A pipeline parallel plan specifying the layers which may not be present
# on all ranks when PP is enabled. For top-level models, this attribute is
# currently defined in respective model code. For base models, this
@@ -3878,6 +3881,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Currently, it only accepts
`tp_plan="auto"` to use predefined plan based on the model. Note that if you use it, you should launch your script accordingly with
`torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
tp_size (`int`, *optional*):
A torch tensor parallel degree. If not provided, it defaults to the world size.
offload_folder (`str` or `os.PathLike`, *optional*):
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
offload_state_dict (`bool`, *optional*):
@@ -3974,6 +3979,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
generation_config = kwargs.pop("generation_config", None)
gguf_file = kwargs.pop("gguf_file", None)
tp_plan = kwargs.pop("tp_plan", None)
tp_size = kwargs.pop("tp_size", None)
key_mapping = kwargs.pop("key_mapping", None)
# Not used anymore -- remove them from the kwargs
_ = kwargs.pop("resume_download", None)
@@ -3986,7 +3992,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
raise ValueError(
"`state_dict` cannot be passed together with a model name or a `gguf_file`. Use one of the two loading strategies."
)
if tp_size is not None and tp_plan is None:
raise ValueError("tp_plan has to be set when tp_size is passed.")
if tp_plan is not None and tp_plan != "auto":
# TODO: we can relax this check when we support taking tp_plan from a json file, for example.
raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
@@ -4046,9 +4053,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
sys.stderr = open(os.devnull, "w")
# This is the easiest way to dispatch to the current process device
device_map = tp_device
# Assuming sharding the model onto the world
world_size = torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
# Shard the model across the whole world when tp_size is not provided
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
if use_auth_token is not None:
warnings.warn(
@@ -4415,6 +4423,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
weights_only=weights_only,
)
# record the tp degree the model is sharded to
model._tp_size = tp_size
# make sure token embedding weights are still tied if needed
model.tie_weights()
@@ -4498,7 +4509,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
elif from_flax:
loading_info = None
return model, loading_info
return model
@staticmethod
@@ -5142,6 +5152,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
return True
return False
@property
def tp_size(self):
"""
Returns the model's tensor parallelism degree.

This is a read-only view of the `_tp_size` attribute. The value is `None` when the model did not undergo
tensor parallel sharding.
"""
# if None, the model didn't undergo tensor parallel sharding
return self._tp_size
@property
def supports_pp_plan(self):
if self._pp_plan is not None: