Fixes for Arcee model (#39001)
* fix modular
* Update modular_arcee.py
* fix
This commit is contained in:
@@ -128,7 +128,6 @@ class ArceeConfig(PretrainedConfig):
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
|
||||
@@ -51,8 +51,6 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class ArceeMLP(nn.Module):
|
||||
"""Arcee MLP with configurable activation function (typically relu2)"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -87,40 +85,6 @@ class ArceeRMSNorm(nn.Module):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class ArceePreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = ArceeConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["ArceeDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, ArceeRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
class ArceeRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: ArceeConfig, device=None):
|
||||
super().__init__()
|
||||
@@ -350,15 +314,37 @@ class ArceeDecoderLayer(GradientCheckpointingLayer):
|
||||
return outputs
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class ArceePreTrainedModel(PreTrainedModel):
|
||||
config_class = ArceeConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["ArceeDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, ArceeRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class ArceeModel(ArceePreTrainedModel):
|
||||
"""
|
||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ArceeDecoderLayer`]
|
||||
|
||||
Args:
|
||||
config: ArceeConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: ArceeConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
@@ -485,10 +471,8 @@ class ArceeModel(ArceePreTrainedModel):
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
|
||||
class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin):
|
||||
"""Arcee Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings)."""
|
||||
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
@@ -598,10 +582,6 @@ class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin):
|
||||
|
||||
@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
|
||||
class ArceeForSequenceClassification(ArceePreTrainedModel):
|
||||
"""
|
||||
The Arcee Model transformer with a sequence classification head on top (linear layer).
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
@@ -689,10 +669,6 @@ class ArceeForSequenceClassification(ArceePreTrainedModel):
|
||||
|
||||
@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
|
||||
class ArceeForQuestionAnswering(ArceePreTrainedModel):
|
||||
"""
|
||||
The Arcee Model transformer with a span classification head on top for extractive question-answering tasks.
|
||||
"""
|
||||
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
def __init__(self, config):
|
||||
@@ -756,10 +732,6 @@ class ArceeForQuestionAnswering(ArceePreTrainedModel):
|
||||
|
||||
@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
|
||||
class ArceeForTokenClassification(ArceePreTrainedModel):
|
||||
"""
|
||||
The Arcee Model transformer with a token classification head on top.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
@@ -22,8 +22,6 @@ from ..llama.modeling_llama import (
|
||||
LlamaForQuestionAnswering,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaForTokenClassification,
|
||||
LlamaModel,
|
||||
LlamaPreTrainedModel,
|
||||
)
|
||||
from ..nemotron.modeling_nemotron import NemotronMLP
|
||||
|
||||
@@ -135,7 +133,6 @@ class ArceeConfig(LlamaConfig):
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
@@ -194,61 +191,26 @@ class ArceeConfig(LlamaConfig):
|
||||
|
||||
|
||||
class ArceeMLP(NemotronMLP):
|
||||
"""Arcee MLP with configurable activation function (typically relu2)"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ArceePreTrainedModel(LlamaPreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ArceeModel(LlamaModel):
|
||||
"""
|
||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ArceeDecoderLayer`]
|
||||
|
||||
Args:
|
||||
config: ArceeConfig
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
|
||||
class ArceeForCausalLM(LlamaForCausalLM):
|
||||
"""Arcee Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings)."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
|
||||
class ArceeForSequenceClassification(LlamaForSequenceClassification):
|
||||
"""
|
||||
The Arcee Model transformer with a sequence classification head on top (linear layer).
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
|
||||
class ArceeForQuestionAnswering(LlamaForQuestionAnswering):
|
||||
"""
|
||||
The Arcee Model transformer with a span classification head on top for extractive question-answering tasks.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
|
||||
class ArceeForTokenClassification(LlamaForTokenClassification):
|
||||
"""
|
||||
The Arcee Model transformer with a token classification head on top.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -258,6 +220,6 @@ __all__ = [
|
||||
"ArceeForQuestionAnswering",
|
||||
"ArceeForSequenceClassification",
|
||||
"ArceeForTokenClassification",
|
||||
"ArceeModel",
|
||||
"ArceePreTrainedModel",
|
||||
"ArceeModel", # noqa: F822
|
||||
"ArceePreTrainedModel", # noqa: F822
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user