Fixes for Arcee model (#39001)

* fix modular

* Update modular_arcee.py

* fix
This commit is contained in:
Cyril Vallez
2025-06-24 15:23:52 +02:00
committed by GitHub
parent 71de20b818
commit 1636a7bcb9
3 changed files with 33 additions and 100 deletions

View File

@@ -128,7 +128,6 @@ class ArceeConfig(PretrainedConfig):
"layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise", "layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }

View File

@@ -51,8 +51,6 @@ logger = logging.get_logger(__name__)
class ArceeMLP(nn.Module): class ArceeMLP(nn.Module):
"""Arcee MLP with configurable activation function (typically relu2)"""
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.config = config self.config = config
@@ -87,40 +85,6 @@ class ArceeRMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
@auto_docstring
class ArceePreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = ArceeConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["ArceeDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, ArceeRMSNorm):
module.weight.data.fill_(1.0)
class ArceeRotaryEmbedding(nn.Module): class ArceeRotaryEmbedding(nn.Module):
def __init__(self, config: ArceeConfig, device=None): def __init__(self, config: ArceeConfig, device=None):
super().__init__() super().__init__()
@@ -350,15 +314,37 @@ class ArceeDecoderLayer(GradientCheckpointingLayer):
return outputs return outputs
@auto_docstring
class ArceePreTrainedModel(PreTrainedModel):
config_class = ArceeConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["ArceeDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, ArceeRMSNorm):
module.weight.data.fill_(1.0)
@auto_docstring @auto_docstring
class ArceeModel(ArceePreTrainedModel): class ArceeModel(ArceePreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ArceeDecoderLayer`]
Args:
config: ArceeConfig
"""
def __init__(self, config: ArceeConfig): def __init__(self, config: ArceeConfig):
super().__init__(config) super().__init__(config)
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
@@ -485,10 +471,8 @@ class ArceeModel(ArceePreTrainedModel):
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@auto_docstring @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin): class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin):
"""Arcee Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings)."""
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"} _tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
@@ -598,10 +582,6 @@ class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin):
@auto_docstring(checkpoint="arcee-ai/AFM-4.5B") @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForSequenceClassification(ArceePreTrainedModel): class ArceeForSequenceClassification(ArceePreTrainedModel):
"""
The Arcee Model transformer with a sequence classification head on top (linear layer).
"""
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.num_labels = config.num_labels self.num_labels = config.num_labels
@@ -689,10 +669,6 @@ class ArceeForSequenceClassification(ArceePreTrainedModel):
@auto_docstring(checkpoint="arcee-ai/AFM-4.5B") @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForQuestionAnswering(ArceePreTrainedModel): class ArceeForQuestionAnswering(ArceePreTrainedModel):
"""
The Arcee Model transformer with a span classification head on top for extractive question-answering tasks.
"""
base_model_prefix = "transformer" base_model_prefix = "transformer"
def __init__(self, config): def __init__(self, config):
@@ -756,10 +732,6 @@ class ArceeForQuestionAnswering(ArceePreTrainedModel):
@auto_docstring(checkpoint="arcee-ai/AFM-4.5B") @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForTokenClassification(ArceePreTrainedModel): class ArceeForTokenClassification(ArceePreTrainedModel):
"""
The Arcee Model transformer with a token classification head on top.
"""
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.num_labels = config.num_labels self.num_labels = config.num_labels

View File

@@ -22,8 +22,6 @@ from ..llama.modeling_llama import (
LlamaForQuestionAnswering, LlamaForQuestionAnswering,
LlamaForSequenceClassification, LlamaForSequenceClassification,
LlamaForTokenClassification, LlamaForTokenClassification,
LlamaModel,
LlamaPreTrainedModel,
) )
from ..nemotron.modeling_nemotron import NemotronMLP from ..nemotron.modeling_nemotron import NemotronMLP
@@ -135,7 +133,6 @@ class ArceeConfig(LlamaConfig):
"layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise", "layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise", "layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise", "layers.*.mlp.down_proj": "rowwise",
} }
@@ -194,61 +191,26 @@ class ArceeConfig(LlamaConfig):
class ArceeMLP(NemotronMLP): class ArceeMLP(NemotronMLP):
"""Arcee MLP with configurable activation function (typically relu2)"""
pass
class ArceePreTrainedModel(LlamaPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
pass
class ArceeModel(LlamaModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ArceeDecoderLayer`]
Args:
config: ArceeConfig
"""
pass pass
@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForCausalLM(LlamaForCausalLM): class ArceeForCausalLM(LlamaForCausalLM):
"""Arcee Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings)."""
pass pass
@auto_docstring(checkpoint="arcee-ai/AFM-4.5B") @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForSequenceClassification(LlamaForSequenceClassification): class ArceeForSequenceClassification(LlamaForSequenceClassification):
"""
The Arcee Model transformer with a sequence classification head on top (linear layer).
"""
pass pass
@auto_docstring(checkpoint="arcee-ai/AFM-4.5B") @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForQuestionAnswering(LlamaForQuestionAnswering): class ArceeForQuestionAnswering(LlamaForQuestionAnswering):
"""
The Arcee Model transformer with a span classification head on top for extractive question-answering tasks.
"""
pass pass
@auto_docstring(checkpoint="arcee-ai/AFM-4.5B") @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForTokenClassification(LlamaForTokenClassification): class ArceeForTokenClassification(LlamaForTokenClassification):
"""
The Arcee Model transformer with a token classification head on top.
"""
pass pass
@@ -258,6 +220,6 @@ __all__ = [
"ArceeForQuestionAnswering", "ArceeForQuestionAnswering",
"ArceeForSequenceClassification", "ArceeForSequenceClassification",
"ArceeForTokenClassification", "ArceeForTokenClassification",
"ArceeModel", "ArceeModel", # noqa: F822
"ArceePreTrainedModel", "ArceePreTrainedModel", # noqa: F822
] ]