Add pipeline parallel plan to PretrainedConfig and PreTrainedModel (#36091)
* Add `base_model_pp_plan` to `PretrainedConfig` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add `_pp_plan` to `PreTrainedModel` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add both to Llama for testing Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Fix type error Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Update to suggested schema Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * `_pp_plan` keys are not patterns Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Simplify schema Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Fix typing error Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Update input name for Llama Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Aria Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Bamba Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Cohere 1 & 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to diffllama and emu3 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Gemma 1 & 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to GLM and GPT NeoX Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Granite and Helium Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Mistral and Mixtral Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to OLMo 1 & 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Phi and Phi 3 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan for Qwen 2, 2 MoE, 2 VL and 2.5 VL Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan for Starcoder 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add enum for accessing inputs and outputs Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Update type hints to use tuples Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Change outer list to tuple Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --------- Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -74,6 +74,8 @@ class PretrainedConfig(PushToHubMixin):
|
||||
naming of attributes.
|
||||
- **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor
|
||||
parallel plan applied to the sub-module when `model.tensor_parallel` is called.
|
||||
- **base_model_pp_plan** (`Dict[str, Tuple[List[str]]]`) -- A dict that maps child-modules of a base model to a
|
||||
pipeline parallel plan that enables users to place the child-module on the appropriate device.
|
||||
|
||||
Common attributes (present in all subclasses):
|
||||
|
||||
@@ -198,6 +200,7 @@ class PretrainedConfig(PushToHubMixin):
|
||||
is_composition: bool = False
|
||||
attribute_map: Dict[str, str] = {}
|
||||
base_model_tp_plan: Optional[Dict[str, Any]] = None
|
||||
base_model_pp_plan: Optional[Dict[str, Tuple[List[str]]]] = None
|
||||
_auto_class: Optional[str] = None
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
@@ -860,6 +863,9 @@ class PretrainedConfig(PushToHubMixin):
|
||||
# Do not serialize `base_model_tp_plan` for now
|
||||
if "base_model_tp_plan" in serializable_config_dict:
|
||||
del serializable_config_dict["base_model_tp_plan"]
|
||||
# Do not serialize `base_model_pp_plan` for now
|
||||
if "base_model_pp_plan" in serializable_config_dict:
|
||||
del serializable_config_dict["base_model_pp_plan"]
|
||||
|
||||
return serializable_config_dict
|
||||
|
||||
@@ -882,6 +888,9 @@ class PretrainedConfig(PushToHubMixin):
|
||||
# Do not serialize `base_model_tp_plan` for now
|
||||
if "base_model_tp_plan" in output:
|
||||
del output["base_model_tp_plan"]
|
||||
# Do not serialize `base_model_pp_plan` for now
|
||||
if "base_model_pp_plan" in output:
|
||||
del output["base_model_pp_plan"]
|
||||
|
||||
# Transformers version when serializing the model
|
||||
output["transformers_version"] = __version__
|
||||
|
||||
@@ -28,6 +28,7 @@ import tempfile
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import partial, wraps
|
||||
from threading import Thread
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
|
||||
@@ -923,6 +924,11 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
||||
return weights_name
|
||||
|
||||
|
||||
class PipelineParallel(Enum):
|
||||
inputs: 0
|
||||
outputs: 1
|
||||
|
||||
|
||||
class ModuleUtilsMixin:
|
||||
"""
|
||||
A few utilities for `torch.nn.Modules`, to be used as a mixin.
|
||||
@@ -1312,6 +1318,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# `config.base_model_tp_plan` during `post_init`.
|
||||
_tp_plan = None
|
||||
|
||||
# A pipeline parallel plan specifying the layers which may not be present
|
||||
# on all ranks when PP is enabled. For top-level models, this attribute is
|
||||
# currently defined in respective model code. For base models, this
|
||||
# attribute comes from `config.base_model_pp_plan` during `post_init`.
|
||||
#
|
||||
# The variable names for the inputs and outputs of the specified layers can
|
||||
# be indexed using the `PipelineParallel` enum as follows:
|
||||
# - `_pp_plan["layers"][PipelineParallel.inputs]`
|
||||
# - `_pp_plan["layers"][PipelineParallel.outputs]`
|
||||
_pp_plan = None
|
||||
|
||||
# This flag signal that the model can be used as an efficient backend in TGI and vLLM
|
||||
# In practice, it means that they support attention interface functions, fully pass the kwargs
|
||||
# through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
|
||||
@@ -1374,6 +1391,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# If current model is a base model, attach `base_model_tp_plan` from config
|
||||
if self.base_model is self:
|
||||
self._tp_plan = self.config.base_model_tp_plan
|
||||
# If current model is a base model, attach `base_model_pp_plan` from config
|
||||
if self.base_model is self:
|
||||
self._pp_plan = self.config.base_model_pp_plan
|
||||
|
||||
def dequantize(self):
|
||||
"""
|
||||
@@ -5196,6 +5216,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# function to every submodule.
|
||||
self.apply(tplize)
|
||||
|
||||
@property
|
||||
def supports_pp_plan(self):
|
||||
if self._pp_plan is not None:
|
||||
return True
|
||||
# Check if base model has PP plan
|
||||
if getattr(self.base_model, "_pp_plan", None) is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def loss_function(self):
|
||||
if hasattr(self, "_loss_function"):
|
||||
|
||||
@@ -144,6 +144,11 @@ class AriaTextConfig(PretrainedConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
base_config_key = "text_config"
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -1141,6 +1141,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
|
||||
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
config_class = AriaTextConfig
|
||||
|
||||
def __init__(self, config: AriaTextConfig):
|
||||
|
||||
@@ -1446,6 +1446,7 @@ class BambaModel(BambaPreTrainedModel):
|
||||
class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -148,6 +148,11 @@ class CohereConfig(PretrainedConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -780,6 +780,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -148,6 +148,11 @@ class Cohere2Config(PretrainedConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -781,6 +781,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config: Cohere2Config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -173,6 +173,11 @@ class Cohere2Config(PretrainedConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1019,6 +1019,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -1598,6 +1598,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
config_class = Emu3TextConfig
|
||||
|
||||
def __init__(self, config):
|
||||
|
||||
@@ -102,6 +102,11 @@ class GemmaConfig(PretrainedConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -752,6 +752,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -126,6 +126,11 @@ class GemmaConfig(PretrainedConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -106,6 +106,11 @@ class Gemma2Config(PretrainedConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -790,6 +790,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -132,6 +132,11 @@ class Gemma2Config(PretrainedConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -93,6 +93,11 @@ class GlmConfig(PretrainedConfig):
|
||||
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
|
||||
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -761,6 +761,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -137,6 +137,12 @@ class GPTNeoXConfig(PretrainedConfig):
|
||||
"layers.*.mlp.dense_h_to_4h": "colwise",
|
||||
"layers.*.mlp.dense_4h_to_h": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_in": (["input_ids"], ["inputs_embeds"]),
|
||||
"emb_dropout": (["inputs_embeds"], ["hidden_states"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"final_layer_norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -758,6 +758,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["embed_out.weight"]
|
||||
_tp_plan = {"embed_out": "colwise_rep"}
|
||||
_pp_plan = {"embed_out": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -456,6 +456,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["embed_out.weight"]
|
||||
_tp_plan = {"embed_out": "colwise_rep"}
|
||||
_pp_plan = {"embed_out": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -122,6 +122,11 @@ class GraniteConfig(PretrainedConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -764,6 +764,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -95,6 +95,11 @@ class HeliumConfig(PretrainedConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -748,6 +748,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config: HeliumConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -151,6 +151,11 @@ class LlamaConfig(PretrainedConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -750,6 +750,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -107,6 +107,11 @@ class MistralConfig(PretrainedConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -751,6 +751,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -119,6 +119,11 @@ class MixtralConfig(PretrainedConfig):
|
||||
"layers.*.block_sparse_moe.experts.*.w2": "rowwise",
|
||||
"layers.*.block_sparse_moe.experts.*.w3": "colwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -967,6 +967,7 @@ def load_balancing_loss_func(
|
||||
class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -115,6 +115,11 @@ class OlmoConfig(PretrainedConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -726,6 +726,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -98,6 +98,11 @@ class Olmo2Config(PretrainedConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -727,6 +727,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -109,6 +109,11 @@ class Olmo2Config(OlmoConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -146,6 +146,12 @@ class PhiConfig(PretrainedConfig):
|
||||
"layers.*.mlp.fc1": "colwise",
|
||||
"layers.*.mlp.fc2": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"embed_dropout": (["inputs_embeds"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"final_layernorm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -724,6 +724,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -113,6 +113,11 @@ class Phi3Config(PretrainedConfig):
|
||||
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
|
||||
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -821,6 +821,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -139,6 +139,11 @@ class Qwen2Config(PretrainedConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -735,6 +735,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -184,6 +184,11 @@ class Qwen2_5_VLConfig(PretrainedConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -160,6 +160,11 @@ class Qwen2MoeConfig(PretrainedConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1217,6 +1217,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
|
||||
class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@@ -173,6 +173,11 @@ class Qwen2VLConfig(PretrainedConfig):
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -143,6 +143,11 @@ class Starcoder2Config(PretrainedConfig):
|
||||
"layers.*.mlp.c_fc": "colwise",
|
||||
"layers.*.mlp.c_proj": "colwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -747,6 +747,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
Reference in New Issue
Block a user