From 1636a7bcb942370bb4098c8e67e4c3d3fd6a1740 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Jun 2025 15:23:52 +0200 Subject: [PATCH] Fixes for Arcee model (#39001) * fix modular * Update modular_arcee.py * fix --- .../models/arcee/configuration_arcee.py | 1 - .../models/arcee/modeling_arcee.py | 88 +++++++------------ .../models/arcee/modular_arcee.py | 44 +--------- 3 files changed, 33 insertions(+), 100 deletions(-) diff --git a/src/transformers/models/arcee/configuration_arcee.py b/src/transformers/models/arcee/configuration_arcee.py index b74dd1a4fe..909783c5d8 100644 --- a/src/transformers/models/arcee/configuration_arcee.py +++ b/src/transformers/models/arcee/configuration_arcee.py @@ -128,7 +128,6 @@ class ArceeConfig(PretrainedConfig): "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py index e9d59eb4d8..dc8b7880c4 100644 --- a/src/transformers/models/arcee/modeling_arcee.py +++ b/src/transformers/models/arcee/modeling_arcee.py @@ -51,8 +51,6 @@ logger = logging.get_logger(__name__) class ArceeMLP(nn.Module): - """Arcee MLP with configurable activation function (typically relu2)""" - def __init__(self, config): super().__init__() self.config = config @@ -87,40 +85,6 @@ class ArceeRMSNorm(nn.Module): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -@auto_docstring -class ArceePreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = ArceeConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["ArceeDecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - _supports_attention_backend = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, ArceeRMSNorm): - module.weight.data.fill_(1.0) - - class ArceeRotaryEmbedding(nn.Module): def __init__(self, config: ArceeConfig, device=None): super().__init__() @@ -350,15 +314,37 @@ class ArceeDecoderLayer(GradientCheckpointingLayer): return outputs +@auto_docstring +class ArceePreTrainedModel(PreTrainedModel): + config_class = ArceeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["ArceeDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, ArceeRMSNorm): + module.weight.data.fill_(1.0) + + @auto_docstring class ArceeModel(ArceePreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ArceeDecoderLayer`] - - Args: - config: ArceeConfig - """ - def __init__(self, config: ArceeConfig): super().__init__(config) self.padding_idx = config.pad_token_id @@ -485,10 +471,8 @@ class ArceeModel(ArceePreTrainedModel): class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... -@auto_docstring +@auto_docstring(checkpoint="arcee-ai/AFM-4.5B") class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin): - """Arcee Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).""" - _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} @@ -598,10 +582,6 @@ class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin): @auto_docstring(checkpoint="arcee-ai/AFM-4.5B") class ArceeForSequenceClassification(ArceePreTrainedModel): - """ - The Arcee Model transformer with a sequence classification head on top (linear layer). - """ - def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels @@ -689,10 +669,6 @@ class ArceeForSequenceClassification(ArceePreTrainedModel): @auto_docstring(checkpoint="arcee-ai/AFM-4.5B") class ArceeForQuestionAnswering(ArceePreTrainedModel): - """ - The Arcee Model transformer with a span classification head on top for extractive question-answering tasks. - """ - base_model_prefix = "transformer" def __init__(self, config): @@ -756,10 +732,6 @@ class ArceeForQuestionAnswering(ArceePreTrainedModel): @auto_docstring(checkpoint="arcee-ai/AFM-4.5B") class ArceeForTokenClassification(ArceePreTrainedModel): - """ - The Arcee Model transformer with a token classification head on top. - """ - def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels diff --git a/src/transformers/models/arcee/modular_arcee.py b/src/transformers/models/arcee/modular_arcee.py index b77906ae3c..7be3b8031a 100644 --- a/src/transformers/models/arcee/modular_arcee.py +++ b/src/transformers/models/arcee/modular_arcee.py @@ -22,8 +22,6 @@ from ..llama.modeling_llama import ( LlamaForQuestionAnswering, LlamaForSequenceClassification, LlamaForTokenClassification, - LlamaModel, - LlamaPreTrainedModel, ) from ..nemotron.modeling_nemotron import NemotronMLP @@ -135,7 +133,6 @@ class ArceeConfig(LlamaConfig): "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } @@ -194,61 +191,26 @@ class ArceeConfig(LlamaConfig): class ArceeMLP(NemotronMLP): - """Arcee MLP with configurable activation function (typically relu2)""" - - pass - - -class ArceePreTrainedModel(LlamaPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - pass - - -class ArceeModel(LlamaModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ArceeDecoderLayer`] - - Args: - config: ArceeConfig - """ - pass +@auto_docstring(checkpoint="arcee-ai/AFM-4.5B") class ArceeForCausalLM(LlamaForCausalLM): - """Arcee Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).""" - pass @auto_docstring(checkpoint="arcee-ai/AFM-4.5B") class ArceeForSequenceClassification(LlamaForSequenceClassification): - """ - The Arcee Model transformer with a sequence classification head on top (linear layer). - """ - pass @auto_docstring(checkpoint="arcee-ai/AFM-4.5B") class ArceeForQuestionAnswering(LlamaForQuestionAnswering): - """ - The Arcee Model transformer with a span classification head on top for extractive question-answering tasks. - """ - pass @auto_docstring(checkpoint="arcee-ai/AFM-4.5B") class ArceeForTokenClassification(LlamaForTokenClassification): - """ - The Arcee Model transformer with a token classification head on top. - """ - pass @@ -258,6 +220,6 @@ __all__ = [ "ArceeForQuestionAnswering", "ArceeForSequenceClassification", "ArceeForTokenClassification", - "ArceeModel", - "ArceePreTrainedModel", + "ArceeModel", # noqa: F822 + "ArceePreTrainedModel", # noqa: F822 ]