(Part 2) feat: allow for tp_size attr for tplizing the model (#37054)
* feat: custom tp_size, new transformers tp interface Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com> * fix: review cmt - error when tp_plan not set for tp_size Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com> * fix: nit in docs Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com> --------- Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Matej Sirovatka <54212263+S1ro1@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
dac443414e
commit
7d76876498
@@ -1788,6 +1788,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
# for example.
|
||||
_tp_plan = None
|
||||
|
||||
# tensor parallel degree to which model is sharded to.
|
||||
_tp_size = None
|
||||
|
||||
# A pipeline parallel plan specifying the layers which may not be present
|
||||
# on all ranks when PP is enabled. For top-level models, this attribute is
|
||||
# currently defined in respective model code. For base models, this
|
||||
@@ -3878,6 +3881,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Currently, it only accepts
|
||||
`tp_plan="auto"` to use predefined plan based on the model. Note that if you use it, you should launch your script accordingly with
|
||||
`torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
|
||||
tp_size (`str`, *optional*):
|
||||
A torch tensor parallel degree. If not provided would default to world size.
|
||||
offload_folder (`str` or `os.PathLike`, *optional*):
|
||||
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
|
||||
offload_state_dict (`bool`, *optional*):
|
||||
@@ -3974,6 +3979,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
generation_config = kwargs.pop("generation_config", None)
|
||||
gguf_file = kwargs.pop("gguf_file", None)
|
||||
tp_plan = kwargs.pop("tp_plan", None)
|
||||
tp_size = kwargs.pop("tp_size", None)
|
||||
key_mapping = kwargs.pop("key_mapping", None)
|
||||
# Not used anymore -- remove them from the kwargs
|
||||
_ = kwargs.pop("resume_download", None)
|
||||
@@ -3986,7 +3992,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
raise ValueError(
|
||||
"`state_dict` cannot be passed together with a model name or a `gguf_file`. Use one of the two loading strategies."
|
||||
)
|
||||
|
||||
if tp_size is not None and tp_plan is None:
|
||||
raise ValueError("tp_plan has to be set when tp_size is passed.")
|
||||
if tp_plan is not None and tp_plan != "auto":
|
||||
# TODO: we can relax this check when we support taking tp_plan from a json file, for example.
|
||||
raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
|
||||
@@ -4046,9 +4053,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
sys.stderr = open(os.devnull, "w")
|
||||
# This is the easiest way to dispatch to the current process device
|
||||
device_map = tp_device
|
||||
# Assuming sharding the model onto the world
|
||||
world_size = torch.distributed.get_world_size()
|
||||
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
|
||||
|
||||
# Assuming sharding the model onto the world when tp_size not provided
|
||||
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
|
||||
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
|
||||
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
@@ -4415,6 +4423,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
weights_only=weights_only,
|
||||
)
|
||||
|
||||
# record tp degree the model sharded to
|
||||
model._tp_size = tp_size
|
||||
|
||||
# make sure token embedding weights are still tied if needed
|
||||
model.tie_weights()
|
||||
|
||||
@@ -4498,7 +4509,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
elif from_flax:
|
||||
loading_info = None
|
||||
return model, loading_info
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
@@ -5142,6 +5152,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
"""
|
||||
Returns the model's tensor parallelism degree.
|
||||
"""
|
||||
# if None, the model didn't undergo tensor parallel sharding
|
||||
return self._tp_size
|
||||
|
||||
@property
|
||||
def supports_pp_plan(self):
|
||||
if self._pp_plan is not None:
|
||||
|
||||
Reference in New Issue
Block a user