Add pipeline parallel plan to PretrainedConfig and PreTrainedModel (#36091)
* Add `base_model_pp_plan` to `PretrainedConfig` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add `_pp_plan` to `PreTrainedModel` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add both to Llama for testing Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Fix type error Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Update to suggested schema Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * `_pp_plan` keys are not patterns Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Simplify schema Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Fix typing error Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Update input name for Llama Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Aria Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Bamba Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Cohere 1 & 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to diffllama and emu3 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Gemma 1 & 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to GLM and GPT NeoX Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Granite and Helium Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Mistral and Mixtral Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to OLMo 1 & 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan to Phi and Phi 3 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan for Qwen 2, 2 MoE, 2 VL and 2.5 VL Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add pp plan for Starcoder 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Add enum for accessing inputs and outputs Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Update type hints to use tuples Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * Change outer list to tuple Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --------- Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -74,6 +74,8 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
naming of attributes.
|
naming of attributes.
|
||||||
- **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor
|
- **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor
|
||||||
parallel plan applied to the sub-module when `model.tensor_parallel` is called.
|
parallel plan applied to the sub-module when `model.tensor_parallel` is called.
|
||||||
|
- **base_model_pp_plan** (`Dict[str, Tuple[List[str]]]`) -- A dict that maps child-modules of a base model to a
|
||||||
|
pipeline parallel plan that enables users to place the child-module on the appropriate device.
|
||||||
|
|
||||||
Common attributes (present in all subclasses):
|
Common attributes (present in all subclasses):
|
||||||
|
|
||||||
@@ -198,6 +200,7 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
is_composition: bool = False
|
is_composition: bool = False
|
||||||
attribute_map: Dict[str, str] = {}
|
attribute_map: Dict[str, str] = {}
|
||||||
base_model_tp_plan: Optional[Dict[str, Any]] = None
|
base_model_tp_plan: Optional[Dict[str, Any]] = None
|
||||||
|
base_model_pp_plan: Optional[Dict[str, Tuple[List[str]]]] = None
|
||||||
_auto_class: Optional[str] = None
|
_auto_class: Optional[str] = None
|
||||||
|
|
||||||
def __setattr__(self, key, value):
|
def __setattr__(self, key, value):
|
||||||
@@ -860,6 +863,9 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
# Do not serialize `base_model_tp_plan` for now
|
# Do not serialize `base_model_tp_plan` for now
|
||||||
if "base_model_tp_plan" in serializable_config_dict:
|
if "base_model_tp_plan" in serializable_config_dict:
|
||||||
del serializable_config_dict["base_model_tp_plan"]
|
del serializable_config_dict["base_model_tp_plan"]
|
||||||
|
# Do not serialize `base_model_pp_plan` for now
|
||||||
|
if "base_model_pp_plan" in serializable_config_dict:
|
||||||
|
del serializable_config_dict["base_model_pp_plan"]
|
||||||
|
|
||||||
return serializable_config_dict
|
return serializable_config_dict
|
||||||
|
|
||||||
@@ -882,6 +888,9 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
# Do not serialize `base_model_tp_plan` for now
|
# Do not serialize `base_model_tp_plan` for now
|
||||||
if "base_model_tp_plan" in output:
|
if "base_model_tp_plan" in output:
|
||||||
del output["base_model_tp_plan"]
|
del output["base_model_tp_plan"]
|
||||||
|
# Do not serialize `base_model_pp_plan` for now
|
||||||
|
if "base_model_pp_plan" in output:
|
||||||
|
del output["base_model_pp_plan"]
|
||||||
|
|
||||||
# Transformers version when serializing the model
|
# Transformers version when serializing the model
|
||||||
output["transformers_version"] = __version__
|
output["transformers_version"] = __version__
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import tempfile
|
|||||||
import warnings
|
import warnings
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
|
||||||
@@ -923,6 +924,11 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
|||||||
return weights_name
|
return weights_name
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineParallel(Enum):
|
||||||
|
inputs: 0
|
||||||
|
outputs: 1
|
||||||
|
|
||||||
|
|
||||||
class ModuleUtilsMixin:
|
class ModuleUtilsMixin:
|
||||||
"""
|
"""
|
||||||
A few utilities for `torch.nn.Modules`, to be used as a mixin.
|
A few utilities for `torch.nn.Modules`, to be used as a mixin.
|
||||||
@@ -1312,6 +1318,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# `config.base_model_tp_plan` during `post_init`.
|
# `config.base_model_tp_plan` during `post_init`.
|
||||||
_tp_plan = None
|
_tp_plan = None
|
||||||
|
|
||||||
|
# A pipeline parallel plan specifying the layers which may not be present
|
||||||
|
# on all ranks when PP is enabled. For top-level models, this attribute is
|
||||||
|
# currently defined in respective model code. For base models, this
|
||||||
|
# attribute comes from `config.base_model_pp_plan` during `post_init`.
|
||||||
|
#
|
||||||
|
# The variable names for the inputs and outputs of the specified layers can
|
||||||
|
# be indexed using the `PipelineParallel` enum as follows:
|
||||||
|
# - `_pp_plan["layers"][PipelineParallel.inputs]`
|
||||||
|
# - `_pp_plan["layers"][PipelineParallel.outputs]`
|
||||||
|
_pp_plan = None
|
||||||
|
|
||||||
# This flag signal that the model can be used as an efficient backend in TGI and vLLM
|
# This flag signal that the model can be used as an efficient backend in TGI and vLLM
|
||||||
# In practice, it means that they support attention interface functions, fully pass the kwargs
|
# In practice, it means that they support attention interface functions, fully pass the kwargs
|
||||||
# through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
|
# through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
|
||||||
@@ -1374,6 +1391,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# If current model is a base model, attach `base_model_tp_plan` from config
|
# If current model is a base model, attach `base_model_tp_plan` from config
|
||||||
if self.base_model is self:
|
if self.base_model is self:
|
||||||
self._tp_plan = self.config.base_model_tp_plan
|
self._tp_plan = self.config.base_model_tp_plan
|
||||||
|
# If current model is a base model, attach `base_model_pp_plan` from config
|
||||||
|
if self.base_model is self:
|
||||||
|
self._pp_plan = self.config.base_model_pp_plan
|
||||||
|
|
||||||
def dequantize(self):
|
def dequantize(self):
|
||||||
"""
|
"""
|
||||||
@@ -5196,6 +5216,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# function to every submodule.
|
# function to every submodule.
|
||||||
self.apply(tplize)
|
self.apply(tplize)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_pp_plan(self):
|
||||||
|
if self._pp_plan is not None:
|
||||||
|
return True
|
||||||
|
# Check if base model has PP plan
|
||||||
|
if getattr(self.base_model, "_pp_plan", None) is not None:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def loss_function(self):
|
def loss_function(self):
|
||||||
if hasattr(self, "_loss_function"):
|
if hasattr(self, "_loss_function"):
|
||||||
|
|||||||
@@ -144,6 +144,11 @@ class AriaTextConfig(PretrainedConfig):
|
|||||||
"layers.*.mlp.up_proj": "colwise",
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
base_config_key = "text_config"
|
base_config_key = "text_config"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -1141,6 +1141,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
|
|||||||
|
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
config_class = AriaTextConfig
|
config_class = AriaTextConfig
|
||||||
|
|
||||||
def __init__(self, config: AriaTextConfig):
|
def __init__(self, config: AriaTextConfig):
|
||||||
|
|||||||
@@ -1446,6 +1446,7 @@ class BambaModel(BambaPreTrainedModel):
|
|||||||
class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
|
class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -148,6 +148,11 @@ class CohereConfig(PretrainedConfig):
|
|||||||
"layers.*.mlp.up_proj": "colwise",
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -780,6 +780,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|||||||
class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
|
class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -148,6 +148,11 @@ class Cohere2Config(PretrainedConfig):
|
|||||||
"layers.*.mlp.up_proj": "colwise",
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -781,6 +781,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|||||||
class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
|
class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config: Cohere2Config):
|
def __init__(self, config: Cohere2Config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -173,6 +173,11 @@ class Cohere2Config(PretrainedConfig):
|
|||||||
"layers.*.mlp.up_proj": "colwise",
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1019,6 +1019,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|||||||
class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
|
class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -1598,6 +1598,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|||||||
class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
|
class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
config_class = Emu3TextConfig
|
config_class = Emu3TextConfig
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
|||||||
@@ -102,6 +102,11 @@ class GemmaConfig(PretrainedConfig):
|
|||||||
"layers.*.mlp.up_proj": "colwise",
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -752,6 +752,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|||||||
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
|
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -126,6 +126,11 @@ class GemmaConfig(PretrainedConfig):
|
|||||||
"layers.*.mlp.up_proj": "colwise",
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -106,6 +106,11 @@ class Gemma2Config(PretrainedConfig):
|
|||||||
"layers.*.mlp.up_proj": "colwise",
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -790,6 +790,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
|||||||
class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
|
class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -132,6 +132,11 @@ class Gemma2Config(PretrainedConfig):
|
|||||||
"layers.*.mlp.up_proj": "colwise",
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -93,6 +93,11 @@ class GlmConfig(PretrainedConfig):
|
|||||||
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
|
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
|
||||||
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
|
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -761,6 +761,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|||||||
class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
|
class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -137,6 +137,12 @@ class GPTNeoXConfig(PretrainedConfig):
|
|||||||
"layers.*.mlp.dense_h_to_4h": "colwise",
|
"layers.*.mlp.dense_h_to_4h": "colwise",
|
||||||
"layers.*.mlp.dense_4h_to_h": "rowwise",
|
"layers.*.mlp.dense_4h_to_h": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_in": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"emb_dropout": (["inputs_embeds"], ["hidden_states"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"final_layer_norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -758,6 +758,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|||||||
class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin):
|
class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["embed_out.weight"]
|
_tied_weights_keys = ["embed_out.weight"]
|
||||||
_tp_plan = {"embed_out": "colwise_rep"}
|
_tp_plan = {"embed_out": "colwise_rep"}
|
||||||
|
_pp_plan = {"embed_out": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -456,6 +456,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|||||||
class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin):
|
class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["embed_out.weight"]
|
_tied_weights_keys = ["embed_out.weight"]
|
||||||
_tp_plan = {"embed_out": "colwise_rep"}
|
_tp_plan = {"embed_out": "colwise_rep"}
|
||||||
|
_pp_plan = {"embed_out": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -122,6 +122,11 @@ class GraniteConfig(PretrainedConfig):
|
|||||||
"layers.*.mlp.up_proj": "colwise",
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -764,6 +764,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|||||||
class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin):
|
class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -95,6 +95,11 @@ class HeliumConfig(PretrainedConfig):
|
|||||||
"layers.*.mlp.up_proj": "colwise",
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -748,6 +748,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|||||||
class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin):
|
class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config: HeliumConfig):
|
def __init__(self, config: HeliumConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -151,6 +151,11 @@ class LlamaConfig(PretrainedConfig):
|
|||||||
"layers.*.mlp.up_proj": "colwise",
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -750,6 +750,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|||||||
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
|
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -107,6 +107,11 @@ class MistralConfig(PretrainedConfig):
|
|||||||
"layers.*.mlp.up_proj": "colwise",
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -751,6 +751,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|||||||
class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
|
class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -119,6 +119,11 @@ class MixtralConfig(PretrainedConfig):
|
|||||||
"layers.*.block_sparse_moe.experts.*.w2": "rowwise",
|
"layers.*.block_sparse_moe.experts.*.w2": "rowwise",
|
||||||
"layers.*.block_sparse_moe.experts.*.w3": "colwise",
|
"layers.*.block_sparse_moe.experts.*.w3": "colwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -967,6 +967,7 @@ def load_balancing_loss_func(
|
|||||||
class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin):
|
class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -115,6 +115,11 @@ class OlmoConfig(PretrainedConfig):
|
|||||||
"layers.*.mlp.up_proj": "colwise",
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -726,6 +726,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|||||||
class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
|
class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -98,6 +98,11 @@ class Olmo2Config(PretrainedConfig):
|
|||||||
"layers.*.mlp.up_proj": "colwise",
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -727,6 +727,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|||||||
class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin):
|
class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -109,6 +109,11 @@ class Olmo2Config(OlmoConfig):
|
|||||||
"layers.*.mlp.up_proj": "colwise",
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -146,6 +146,12 @@ class PhiConfig(PretrainedConfig):
|
|||||||
"layers.*.mlp.fc1": "colwise",
|
"layers.*.mlp.fc1": "colwise",
|
||||||
"layers.*.mlp.fc2": "rowwise",
|
"layers.*.mlp.fc2": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"embed_dropout": (["inputs_embeds"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"final_layernorm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -724,6 +724,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|||||||
class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin):
|
class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -113,6 +113,11 @@ class Phi3Config(PretrainedConfig):
|
|||||||
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
|
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
|
||||||
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
|
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -821,6 +821,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|||||||
class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
|
class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -139,6 +139,11 @@ class Qwen2Config(PretrainedConfig):
|
|||||||
"layers.*.mlp.up_proj": "colwise",
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -735,6 +735,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|||||||
class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
|
class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -184,6 +184,11 @@ class Qwen2_5_VLConfig(PretrainedConfig):
|
|||||||
"layers.*.mlp.up_proj": "colwise",
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -160,6 +160,11 @@ class Qwen2MoeConfig(PretrainedConfig):
|
|||||||
"layers.*.mlp.up_proj": "colwise",
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1217,6 +1217,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
|
|||||||
class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin):
|
class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -173,6 +173,11 @@ class Qwen2VLConfig(PretrainedConfig):
|
|||||||
"layers.*.mlp.up_proj": "colwise",
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -143,6 +143,11 @@ class Starcoder2Config(PretrainedConfig):
|
|||||||
"layers.*.mlp.c_fc": "colwise",
|
"layers.*.mlp.c_fc": "colwise",
|
||||||
"layers.*.mlp.c_proj": "colwise",
|
"layers.*.mlp.c_proj": "colwise",
|
||||||
}
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -747,6 +747,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
|||||||
class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin):
|
class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
Reference in New Issue
Block a user