From f5fff672db338f8143181b2d8b7612060e14a7f3 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Wed, 12 Feb 2025 09:51:48 +0000 Subject: [PATCH] Add pipeline parallel plan to `PretrainedConfig` and `PreTrainedModel` (#36091) * Add `base_model_pp_plan` to `PretrainedConfig` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add `_pp_plan` to `PreTrainedModel` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add both to Llama for testing Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Fix type error Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Update to suggested schema Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * `_pp_plan` keys are not patterns Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Simplify schema Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Fix typing error Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Update input name for Llama Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Aria Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Bamba Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Cohere 1 & 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to diffllama and emu3 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Gemma 1 & 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to GLM and GPT NeoX Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Granite and Helium Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Mistral and Mixtral Signed-off-by: Harry Mellor 
<19981378+hmellor@users.noreply.github.com> * Add pp plan to OLMo 1 & 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Phi and Phi 3 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan for Qwen 2, 2 MoE, 2 VL and 2.5 VL Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan for Starcoder 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add enum for accessing inputs and outputs Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Update type hints to use tuples Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Change outer list to tuple Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --------- Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/configuration_utils.py | 9 ++++++ src/transformers/modeling_utils.py | 29 +++++++++++++++++++ .../models/aria/configuration_aria.py | 5 ++++ src/transformers/models/aria/modeling_aria.py | 1 + .../models/bamba/modeling_bamba.py | 1 + .../models/cohere/configuration_cohere.py | 5 ++++ .../models/cohere/modeling_cohere.py | 1 + .../models/cohere2/configuration_cohere2.py | 5 ++++ .../models/cohere2/modeling_cohere2.py | 1 + .../models/cohere2/modular_cohere2.py | 5 ++++ .../models/diffllama/modeling_diffllama.py | 1 + src/transformers/models/emu3/modeling_emu3.py | 1 + .../models/gemma/configuration_gemma.py | 5 ++++ .../models/gemma/modeling_gemma.py | 1 + .../models/gemma/modular_gemma.py | 5 ++++ .../models/gemma2/configuration_gemma2.py | 5 ++++ .../models/gemma2/modeling_gemma2.py | 1 + .../models/gemma2/modular_gemma2.py | 5 ++++ .../models/glm/configuration_glm.py | 5 ++++ src/transformers/models/glm/modeling_glm.py | 1 + .../models/gpt_neox/configuration_gpt_neox.py | 6 ++++ .../models/gpt_neox/modeling_gpt_neox.py | 1 + .../models/gpt_neox/modular_gpt_neox.py | 1 + 
.../models/granite/configuration_granite.py | 5 ++++ .../models/granite/modeling_granite.py | 1 + .../models/helium/configuration_helium.py | 5 ++++ .../models/helium/modeling_helium.py | 1 + .../models/llama/configuration_llama.py | 5 ++++ .../models/llama/modeling_llama.py | 1 + .../models/mistral/configuration_mistral.py | 5 ++++ .../models/mistral/modeling_mistral.py | 1 + .../models/mixtral/configuration_mixtral.py | 5 ++++ .../models/mixtral/modeling_mixtral.py | 1 + .../models/olmo/configuration_olmo.py | 5 ++++ src/transformers/models/olmo/modeling_olmo.py | 1 + .../models/olmo2/configuration_olmo2.py | 5 ++++ .../models/olmo2/modeling_olmo2.py | 1 + .../models/olmo2/modular_olmo2.py | 5 ++++ .../models/phi/configuration_phi.py | 6 ++++ src/transformers/models/phi/modeling_phi.py | 1 + .../models/phi3/configuration_phi3.py | 5 ++++ src/transformers/models/phi3/modeling_phi3.py | 1 + .../models/qwen2/configuration_qwen2.py | 5 ++++ .../models/qwen2/modeling_qwen2.py | 1 + .../qwen2_5_vl/configuration_qwen2_5_vl.py | 5 ++++ .../qwen2_moe/configuration_qwen2_moe.py | 5 ++++ .../models/qwen2_moe/modeling_qwen2_moe.py | 1 + .../models/qwen2_vl/configuration_qwen2_vl.py | 5 ++++ .../starcoder2/configuration_starcoder2.py | 5 ++++ .../models/starcoder2/modeling_starcoder2.py | 1 + 50 files changed, 188 insertions(+) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index dfb64fcd08..581032ef7d 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -74,6 +74,8 @@ class PretrainedConfig(PushToHubMixin): naming of attributes. - **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor parallel plan applied to the sub-module when `model.tensor_parallel` is called. 
+ - **base_model_pp_plan** (`Dict[str, Tuple[List[str]]]`) -- A dict that maps child-modules of a base model to a + pipeline parallel plan that enables users to place the child-module on the appropriate device. Common attributes (present in all subclasses): @@ -198,6 +200,7 @@ class PretrainedConfig(PushToHubMixin): is_composition: bool = False attribute_map: Dict[str, str] = {} base_model_tp_plan: Optional[Dict[str, Any]] = None + base_model_pp_plan: Optional[Dict[str, Tuple[List[str]]]] = None _auto_class: Optional[str] = None def __setattr__(self, key, value): @@ -860,6 +863,9 @@ class PretrainedConfig(PushToHubMixin): # Do not serialize `base_model_tp_plan` for now if "base_model_tp_plan" in serializable_config_dict: del serializable_config_dict["base_model_tp_plan"] + # Do not serialize `base_model_pp_plan` for now + if "base_model_pp_plan" in serializable_config_dict: + del serializable_config_dict["base_model_pp_plan"] return serializable_config_dict @@ -882,6 +888,9 @@ class PretrainedConfig(PushToHubMixin): # Do not serialize `base_model_tp_plan` for now if "base_model_tp_plan" in output: del output["base_model_tp_plan"] + # Do not serialize `base_model_pp_plan` for now + if "base_model_pp_plan" in output: + del output["base_model_pp_plan"] # Transformers version when serializing the model output["transformers_version"] = __version__ diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e292b1061a..13c8719b36 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -28,6 +28,7 @@ import tempfile import warnings from contextlib import contextmanager from dataclasses import dataclass +from enum import Enum from functools import partial, wraps from threading import Thread from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union @@ -923,6 +924,11 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: return weights_name +class 
PipelineParallel(int, Enum): + inputs = 0 # tuple index of the input variable names in a pp plan entry + outputs = 1 # tuple index of the output variable names in a pp plan entry + + class ModuleUtilsMixin: """ A few utilities for `torch.nn.Modules`, to be used as a mixin. @@ -1312,6 +1318,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # `config.base_model_tp_plan` during `post_init`. _tp_plan = None + # A pipeline parallel plan specifying the layers which may not be present + # on all ranks when PP is enabled. For top-level models, this attribute is + # currently defined in respective model code. For base models, this + # attribute comes from `config.base_model_pp_plan` during `post_init`. + # + # The variable names for the inputs and outputs of the specified layers can + # be indexed using the `PipelineParallel` enum as follows: + # - `_pp_plan["layers"][PipelineParallel.inputs]` + # - `_pp_plan["layers"][PipelineParallel.outputs]` + _pp_plan = None + # This flag signal that the model can be used as an efficient backend in TGI and vLLM # In practice, it means that they support attention interface functions, fully pass the kwargs # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan @@ -1374,6 +1391,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # If current model is a base model, attach `base_model_tp_plan` from config if self.base_model is self: self._tp_plan = self.config.base_model_tp_plan + # If current model is a base model, attach `base_model_pp_plan` from config + if self.base_model is self: + self._pp_plan = self.config.base_model_pp_plan def dequantize(self): """ @@ -5196,6 +5216,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # function to every submodule. 
self.apply(tplize) + @property + def supports_pp_plan(self): + if self._pp_plan is not None: + return True + # Check if base model has PP plan + if getattr(self.base_model, "_pp_plan", None) is not None: + return True + return False + @property def loss_function(self): if hasattr(self, "_loss_function"): diff --git a/src/transformers/models/aria/configuration_aria.py b/src/transformers/models/aria/configuration_aria.py index ff34d59f5d..fed90c86b4 100644 --- a/src/transformers/models/aria/configuration_aria.py +++ b/src/transformers/models/aria/configuration_aria.py @@ -144,6 +144,11 @@ class AriaTextConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } base_config_key = "text_config" def __init__( diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index ee692c9616..dacc92b795 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -1141,6 +1141,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} config_class = AriaTextConfig def __init__(self, config: AriaTextConfig): diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 41ba1c5b26..6fdce41e5a 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -1446,6 +1446,7 @@ class BambaModel(BambaPreTrainedModel): class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": 
(["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/cohere/configuration_cohere.py b/src/transformers/models/cohere/configuration_cohere.py index dc9ca2cb4d..eeeb236428 100644 --- a/src/transformers/models/cohere/configuration_cohere.py +++ b/src/transformers/models/cohere/configuration_cohere.py @@ -148,6 +148,11 @@ class CohereConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 90b4e6dc63..5101a0f9e0 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -780,6 +780,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/cohere2/configuration_cohere2.py b/src/transformers/models/cohere2/configuration_cohere2.py index 88d3265ead..c792ab3f82 100644 --- a/src/transformers/models/cohere2/configuration_cohere2.py +++ b/src/transformers/models/cohere2/configuration_cohere2.py @@ -148,6 +148,11 @@ class Cohere2Config(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index e900413740..df0cb24d79 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -781,6 +781,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config: Cohere2Config): super().__init__(config) diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index f24e1378ec..979b5abc26 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -173,6 +173,11 @@ class Cohere2Config(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 38d9d3ce00..301668d21a 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -1019,6 +1019,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 38b285be73..ef086ab12e 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1598,6 +1598,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} config_class = Emu3TextConfig def __init__(self, config): diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py index b8470e92fb..2aeb200580 100644 --- a/src/transformers/models/gemma/configuration_gemma.py +++ b/src/transformers/models/gemma/configuration_gemma.py @@ -102,6 +102,11 @@ class GemmaConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 047516a4b1..59b7dc3dc3 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -752,6 +752,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 540ce2b87c..dc8ced15f9 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -126,6 +126,11 @@ class GemmaConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py index 3dfbd6a107..c9e66f8bea 100644 --- a/src/transformers/models/gemma2/configuration_gemma2.py +++ b/src/transformers/models/gemma2/configuration_gemma2.py @@ -106,6 +106,11 @@ class Gemma2Config(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index e21d0b656a..d55fafe056 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -790,6 +790,7 @@ class Gemma2Model(Gemma2PreTrainedModel): class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": 
(["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 351f083f81..76123af3ec 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -132,6 +132,11 @@ class Gemma2Config(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/glm/configuration_glm.py b/src/transformers/models/glm/configuration_glm.py index 553b71cf23..f9a3ab53a9 100644 --- a/src/transformers/models/glm/configuration_glm.py +++ b/src/transformers/models/glm/configuration_glm.py @@ -93,6 +93,11 @@ class GlmConfig(PretrainedConfig): "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 0e73af6b15..54c138212e 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -761,6 +761,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gpt_neox/configuration_gpt_neox.py b/src/transformers/models/gpt_neox/configuration_gpt_neox.py index e570662c10..cea854eabb 100644 --- a/src/transformers/models/gpt_neox/configuration_gpt_neox.py +++ b/src/transformers/models/gpt_neox/configuration_gpt_neox.py @@ -137,6 +137,12 @@ class GPTNeoXConfig(PretrainedConfig): "layers.*.mlp.dense_h_to_4h": "colwise", "layers.*.mlp.dense_4h_to_h": "rowwise", } + base_model_pp_plan = { + "embed_in": (["input_ids"], ["inputs_embeds"]), + "emb_dropout": (["inputs_embeds"], ["hidden_states"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "final_layer_norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index d83ee58af5..efb2982431 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -758,6 +758,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): _tied_weights_keys = ["embed_out.weight"] _tp_plan = {"embed_out": "colwise_rep"} + _pp_plan = {"embed_out": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py index 295882a9ee..3a7cc49542 100644 --- a/src/transformers/models/gpt_neox/modular_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -456,6 +456,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): _tied_weights_keys = ["embed_out.weight"] _tp_plan = {"embed_out": "colwise_rep"} + _pp_plan = {"embed_out": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/granite/configuration_granite.py b/src/transformers/models/granite/configuration_granite.py index 404d60ca32..fc651a94e1 100644 --- a/src/transformers/models/granite/configuration_granite.py +++ b/src/transformers/models/granite/configuration_granite.py @@ -122,6 +122,11 @@ class GraniteConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 52cdc96e64..85c8e97c77 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -764,6 +764,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/helium/configuration_helium.py b/src/transformers/models/helium/configuration_helium.py index 73c2638a88..7b27c6e54b 100644 --- a/src/transformers/models/helium/configuration_helium.py +++ b/src/transformers/models/helium/configuration_helium.py @@ -95,6 +95,11 @@ class HeliumConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 3c17e18e4c..86635f2d72 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -748,6 +748,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config: HeliumConfig): super().__init__(config) diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index 646c06bdc4..066534f109 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -151,6 +151,11 @@ class LlamaConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 566fa57413..a06084e825 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -750,6 +750,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/mistral/configuration_mistral.py b/src/transformers/models/mistral/configuration_mistral.py index c4b874f270..3a237bc734 100644 --- a/src/transformers/models/mistral/configuration_mistral.py +++ b/src/transformers/models/mistral/configuration_mistral.py @@ -107,6 +107,11 @@ class MistralConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index a6d9a54efc..92e555b3d7 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -751,6 +751,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/mixtral/configuration_mixtral.py b/src/transformers/models/mixtral/configuration_mixtral.py index c3f7ec8e4c..d9b02e10fc 100644 --- a/src/transformers/models/mixtral/configuration_mixtral.py +++ b/src/transformers/models/mixtral/configuration_mixtral.py @@ -119,6 +119,11 @@ class MixtralConfig(PretrainedConfig): "layers.*.block_sparse_moe.experts.*.w2": "rowwise", "layers.*.block_sparse_moe.experts.*.w3": "colwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 251187677f..0835e33722 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -967,6 +967,7 @@ def load_balancing_loss_func( class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/olmo/configuration_olmo.py b/src/transformers/models/olmo/configuration_olmo.py index d80910e845..ded0bf4f01 100644 --- a/src/transformers/models/olmo/configuration_olmo.py +++ b/src/transformers/models/olmo/configuration_olmo.py @@ -115,6 +115,11 @@ class OlmoConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + 
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index ef3e10582f..37d15475be 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -726,6 +726,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/olmo2/configuration_olmo2.py b/src/transformers/models/olmo2/configuration_olmo2.py index ce434f5416..222c8e1791 100644 --- a/src/transformers/models/olmo2/configuration_olmo2.py +++ b/src/transformers/models/olmo2/configuration_olmo2.py @@ -98,6 +98,11 @@ class Olmo2Config(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 561b7fdf08..40c912ef1a 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -727,6 +727,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index 04c7f0f486..bc5a9b89d5 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -109,6 +109,11 @@ class Olmo2Config(OlmoConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/phi/configuration_phi.py b/src/transformers/models/phi/configuration_phi.py index 2733d77ff6..06e5cbec2e 100644 --- a/src/transformers/models/phi/configuration_phi.py +++ b/src/transformers/models/phi/configuration_phi.py @@ -146,6 +146,12 @@ class PhiConfig(PretrainedConfig): "layers.*.mlp.fc1": "colwise", "layers.*.mlp.fc2": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "embed_dropout": (["inputs_embeds"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "final_layernorm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index a62249460e..8ab41d2a0c 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -724,6 +724,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/phi3/configuration_phi3.py b/src/transformers/models/phi3/configuration_phi3.py index 6fe6e1cdfc..a6b7ec9baf 100644 --- a/src/transformers/models/phi3/configuration_phi3.py +++ b/src/transformers/models/phi3/configuration_phi3.py @@ -113,6 +113,11 @@ class Phi3Config(PretrainedConfig): "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index ca6992d377..2595278048 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -821,6 +821,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/qwen2/configuration_qwen2.py b/src/transformers/models/qwen2/configuration_qwen2.py index 16ce924b9f..16979865e4 100644 --- a/src/transformers/models/qwen2/configuration_qwen2.py +++ b/src/transformers/models/qwen2/configuration_qwen2.py @@ -139,6 +139,11 @@ class Qwen2Config(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 8310379d83..91eac84ffc 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -735,6 +735,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py index a1cf06d94e..b2bf37ba0c 100644 --- a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py @@ -184,6 +184,11 @@ class Qwen2_5_VLConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py b/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py index ac6e8ae17a..a52b4204a6 100644 --- a/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py @@ -160,6 +160,11 @@ class Qwen2MoeConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index bfb4e81d3e..8157408f42 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1217,6 +1217,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, 
GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py index 49a0836cf9..710738e396 100644 --- a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py @@ -173,6 +173,11 @@ class Qwen2VLConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/starcoder2/configuration_starcoder2.py b/src/transformers/models/starcoder2/configuration_starcoder2.py index 7f21d1f12d..b617a1cad8 100644 --- a/src/transformers/models/starcoder2/configuration_starcoder2.py +++ b/src/transformers/models/starcoder2/configuration_starcoder2.py @@ -143,6 +143,11 @@ class Starcoder2Config(PretrainedConfig): "layers.*.mlp.c_fc": "colwise", "layers.*.mlp.c_proj": "colwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 9314f05b49..f176d5311d 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -747,6 +747,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config)