Add pipeline parallel plan to PretrainedConfig and PreTrainedModel (#36091)

* Add `base_model_pp_plan` to `PretrainedConfig`

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add `_pp_plan` to `PreTrainedModel`

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add both to Llama for testing

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Fix type error

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Update to suggested schema

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* `_pp_plan` keys are not patterns

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Simplify schema

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Fix typing error

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Update input name for Llama

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to Aria

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to Bamba

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to Cohere 1 & 2

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to diffllama and emu3

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to Gemma 1 & 2

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to GLM and GPT NeoX

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to Granite and Helium

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to Mistral and Mixtral

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to OLMo 1 & 2

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to Phi and Phi 3

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan for Qwen 2, 2 MoE, 2 VL and 2.5 VL

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan for Starcoder 2

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add enum for accessing inputs and outputs

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Update type hints to use tuples

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Change outer list to tuple

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

---------

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-02-12 09:51:48 +00:00
committed by GitHub
parent 11afab19c0
commit f5fff672db
50 changed files with 188 additions and 0 deletions

View File

@@ -74,6 +74,8 @@ class PretrainedConfig(PushToHubMixin):
naming of attributes. naming of attributes.
- **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor - **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor
parallel plan applied to the sub-module when `model.tensor_parallel` is called. parallel plan applied to the sub-module when `model.tensor_parallel` is called.
- **base_model_pp_plan** (`Dict[str, Tuple[List[str]]]`) -- A dict that maps child-modules of a base model to a
pipeline parallel plan that enables users to place the child-module on the appropriate device.
Common attributes (present in all subclasses): Common attributes (present in all subclasses):
@@ -198,6 +200,7 @@ class PretrainedConfig(PushToHubMixin):
is_composition: bool = False is_composition: bool = False
attribute_map: Dict[str, str] = {} attribute_map: Dict[str, str] = {}
base_model_tp_plan: Optional[Dict[str, Any]] = None base_model_tp_plan: Optional[Dict[str, Any]] = None
base_model_pp_plan: Optional[Dict[str, Tuple[List[str]]]] = None
_auto_class: Optional[str] = None _auto_class: Optional[str] = None
def __setattr__(self, key, value): def __setattr__(self, key, value):
@@ -860,6 +863,9 @@ class PretrainedConfig(PushToHubMixin):
# Do not serialize `base_model_tp_plan` for now # Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in serializable_config_dict: if "base_model_tp_plan" in serializable_config_dict:
del serializable_config_dict["base_model_tp_plan"] del serializable_config_dict["base_model_tp_plan"]
# Do not serialize `base_model_pp_plan` for now
if "base_model_pp_plan" in serializable_config_dict:
del serializable_config_dict["base_model_pp_plan"]
return serializable_config_dict return serializable_config_dict
@@ -882,6 +888,9 @@ class PretrainedConfig(PushToHubMixin):
# Do not serialize `base_model_tp_plan` for now # Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in output: if "base_model_tp_plan" in output:
del output["base_model_tp_plan"] del output["base_model_tp_plan"]
# Do not serialize `base_model_pp_plan` for now
if "base_model_pp_plan" in output:
del output["base_model_pp_plan"]
# Transformers version when serializing the model # Transformers version when serializing the model
output["transformers_version"] = __version__ output["transformers_version"] = __version__

View File

@@ -28,6 +28,7 @@ import tempfile
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum
from functools import partial, wraps from functools import partial, wraps
from threading import Thread from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
@@ -923,6 +924,11 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
return weights_name return weights_name
class PipelineParallel(Enum):
inputs: 0
outputs: 1
class ModuleUtilsMixin: class ModuleUtilsMixin:
""" """
A few utilities for `torch.nn.Modules`, to be used as a mixin. A few utilities for `torch.nn.Modules`, to be used as a mixin.
@@ -1312,6 +1318,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# `config.base_model_tp_plan` during `post_init`. # `config.base_model_tp_plan` during `post_init`.
_tp_plan = None _tp_plan = None
# A pipeline parallel plan specifying the layers which may not be present
# on all ranks when PP is enabled. For top-level models, this attribute is
# currently defined in respective model code. For base models, this
# attribute comes from `config.base_model_pp_plan` during `post_init`.
#
# The variable names for the inputs and outputs of the specified layers can
# be indexed using the `PipelineParallel` enum as follows:
# - `_pp_plan["layers"][PipelineParallel.inputs]`
# - `_pp_plan["layers"][PipelineParallel.outputs]`
_pp_plan = None
# This flag signal that the model can be used as an efficient backend in TGI and vLLM # This flag signal that the model can be used as an efficient backend in TGI and vLLM
# In practice, it means that they support attention interface functions, fully pass the kwargs # In practice, it means that they support attention interface functions, fully pass the kwargs
# through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
@@ -1374,6 +1391,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# If current model is a base model, attach `base_model_tp_plan` from config # If current model is a base model, attach `base_model_tp_plan` from config
if self.base_model is self: if self.base_model is self:
self._tp_plan = self.config.base_model_tp_plan self._tp_plan = self.config.base_model_tp_plan
# If current model is a base model, attach `base_model_pp_plan` from config
if self.base_model is self:
self._pp_plan = self.config.base_model_pp_plan
def dequantize(self): def dequantize(self):
""" """
@@ -5196,6 +5216,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# function to every submodule. # function to every submodule.
self.apply(tplize) self.apply(tplize)
@property
def supports_pp_plan(self):
if self._pp_plan is not None:
return True
# Check if base model has PP plan
if getattr(self.base_model, "_pp_plan", None) is not None:
return True
return False
@property @property
def loss_function(self): def loss_function(self):
if hasattr(self, "_loss_function"): if hasattr(self, "_loss_function"):

View File

@@ -144,6 +144,11 @@ class AriaTextConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
base_config_key = "text_config" base_config_key = "text_config"
def __init__( def __init__(

View File

@@ -1141,6 +1141,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
config_class = AriaTextConfig config_class = AriaTextConfig
def __init__(self, config: AriaTextConfig): def __init__(self, config: AriaTextConfig):

View File

@@ -1446,6 +1446,7 @@ class BambaModel(BambaPreTrainedModel):
class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin): class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -148,6 +148,11 @@ class CohereConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -780,6 +780,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -148,6 +148,11 @@ class Cohere2Config(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -781,6 +781,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config: Cohere2Config): def __init__(self, config: Cohere2Config):
super().__init__(config) super().__init__(config)

View File

@@ -173,6 +173,11 @@ class Cohere2Config(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -1019,6 +1019,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin): class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -1598,6 +1598,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
config_class = Emu3TextConfig config_class = Emu3TextConfig
def __init__(self, config): def __init__(self, config):

View File

@@ -102,6 +102,11 @@ class GemmaConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -752,6 +752,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -126,6 +126,11 @@ class GemmaConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -106,6 +106,11 @@ class Gemma2Config(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -790,6 +790,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -132,6 +132,11 @@ class Gemma2Config(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -93,6 +93,11 @@ class GlmConfig(PretrainedConfig):
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -761,6 +761,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -137,6 +137,12 @@ class GPTNeoXConfig(PretrainedConfig):
"layers.*.mlp.dense_h_to_4h": "colwise", "layers.*.mlp.dense_h_to_4h": "colwise",
"layers.*.mlp.dense_4h_to_h": "rowwise", "layers.*.mlp.dense_4h_to_h": "rowwise",
} }
base_model_pp_plan = {
"embed_in": (["input_ids"], ["inputs_embeds"]),
"emb_dropout": (["inputs_embeds"], ["hidden_states"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"final_layer_norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -758,6 +758,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["embed_out.weight"] _tied_weights_keys = ["embed_out.weight"]
_tp_plan = {"embed_out": "colwise_rep"} _tp_plan = {"embed_out": "colwise_rep"}
_pp_plan = {"embed_out": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -456,6 +456,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["embed_out.weight"] _tied_weights_keys = ["embed_out.weight"]
_tp_plan = {"embed_out": "colwise_rep"} _tp_plan = {"embed_out": "colwise_rep"}
_pp_plan = {"embed_out": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -122,6 +122,11 @@ class GraniteConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -764,6 +764,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -95,6 +95,11 @@ class HeliumConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -748,6 +748,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin): class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config: HeliumConfig): def __init__(self, config: HeliumConfig):
super().__init__(config) super().__init__(config)

View File

@@ -151,6 +151,11 @@ class LlamaConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -750,6 +750,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -107,6 +107,11 @@ class MistralConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -751,6 +751,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -119,6 +119,11 @@ class MixtralConfig(PretrainedConfig):
"layers.*.block_sparse_moe.experts.*.w2": "rowwise", "layers.*.block_sparse_moe.experts.*.w2": "rowwise",
"layers.*.block_sparse_moe.experts.*.w3": "colwise", "layers.*.block_sparse_moe.experts.*.w3": "colwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -967,6 +967,7 @@ def load_balancing_loss_func(
class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -115,6 +115,11 @@ class OlmoConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -726,6 +726,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -98,6 +98,11 @@ class Olmo2Config(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -727,6 +727,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -109,6 +109,11 @@ class Olmo2Config(OlmoConfig):
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -146,6 +146,12 @@ class PhiConfig(PretrainedConfig):
"layers.*.mlp.fc1": "colwise", "layers.*.mlp.fc1": "colwise",
"layers.*.mlp.fc2": "rowwise", "layers.*.mlp.fc2": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"embed_dropout": (["inputs_embeds"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"final_layernorm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -724,6 +724,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -113,6 +113,11 @@ class Phi3Config(PretrainedConfig):
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -821,6 +821,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -139,6 +139,11 @@ class Qwen2Config(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -735,6 +735,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -184,6 +184,11 @@ class Qwen2_5_VLConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -160,6 +160,11 @@ class Qwen2MoeConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -1217,6 +1217,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)

View File

@@ -173,6 +173,11 @@ class Qwen2VLConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -143,6 +143,11 @@ class Starcoder2Config(PretrainedConfig):
"layers.*.mlp.c_fc": "colwise", "layers.*.mlp.c_fc": "colwise",
"layers.*.mlp.c_proj": "colwise", "layers.*.mlp.c_proj": "colwise",
} }
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__( def __init__(
self, self,

View File

@@ -747,6 +747,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)