Add pipeline parallel plan to PretrainedConfig and PreTrainedModel (#36091)
* Add `base_model_pp_plan` to `PretrainedConfig` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add `_pp_plan` to `PreTrainedModel` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add both to Llama for testing Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Fix type error Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Update to suggested schema Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * `_pp_plan` keys are not patterns Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Simplify schema Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Fix typing error Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Update input name for Llama Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Aria Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Bamba Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Cohere 1 & 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to diffllama and emu3 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Gemma 1 & 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to GLM and GPT NeoX Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Granite and Helium Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Mistral and Mixtral Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to OLMo 1 & 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Phi and Phi 3 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan for Qwen 2, 2 MoE, 2 VL and 2.5 VL Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan for Starcoder 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add enum for accessing inputs and outputs Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Update type hints to use tuples Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Change outer list to tuple Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --------- Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -144,6 +144,11 @@ class AriaTextConfig(PretrainedConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
base_config_key = "text_config"
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -1141,6 +1141,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
|
||||
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
config_class = AriaTextConfig
|
||||
|
||||
def __init__(self, config: AriaTextConfig):
|
||||
|
||||
Reference in New Issue
Block a user