From 9682d07f92bffcc2d091a32cfbb3692884e7cacd Mon Sep 17 00:00:00 2001 From: Paul Pak <52512091+paulpak58@users.noreply.github.com> Date: Thu, 10 Jul 2025 08:07:33 -0600 Subject: [PATCH] LFM2 (#39340) * [modeling][lfm2] LFM2 model on 4.53.0 interface * [configuration] hook in LFM2 keys * [modeling][lfm2] update modeling interface for 4.53.1 * [modeling][lfm2] apply mask to hidden conv states * [misc] ruff format/lint * [modeling][lfm2] minor: NotImplemented legacy cache conversion * Create lfm2.md * create nice modular * style * Update modeling_auto.py * clean and start adding tests * style * Update test_modeling_lfm2.py * Update __init__.py * small test model size * config * small fix * fix * remove useless config attrs -> block_dim and conv_dim are hiden_size * fix prepare inputs * fix config * test * typo * skip tests accordingly * config docstrings * add doc to .md * skip config docstring check --------- Co-authored-by: Maxime Labonne <81252890+mlabonne@users.noreply.github.com> Co-authored-by: Cyril Vallez --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/lfm2.md | 84 ++ src/transformers/generation/utils.py | 1 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 2 + .../models/bamba/modeling_bamba.py | 2 +- .../models/falcon_h1/modeling_falcon_h1.py | 5 +- .../models/jamba/modeling_jamba.py | 2 +- src/transformers/models/lfm2/__init__.py | 27 + .../models/lfm2/configuration_lfm2.py | 165 ++++ src/transformers/models/lfm2/modeling_lfm2.py | 763 ++++++++++++++++++ src/transformers/models/lfm2/modular_lfm2.py | 500 ++++++++++++ tests/generation/test_utils.py | 1 + tests/models/lfm2/__init__.py | 0 tests/models/lfm2/test_modeling_lfm2.py | 93 +++ utils/check_config_attributes.py | 1 + 17 files changed, 1645 insertions(+), 6 deletions(-) create mode 100644 docs/source/en/model_doc/lfm2.md create mode 100644 src/transformers/models/lfm2/__init__.py create mode 100644 src/transformers/models/lfm2/configuration_lfm2.py create mode 100644 src/transformers/models/lfm2/modeling_lfm2.py create mode 100644 src/transformers/models/lfm2/modular_lfm2.py create mode 100644 tests/models/lfm2/__init__.py create mode 100644 tests/models/lfm2/test_modeling_lfm2.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 3e2280ca91..12a46e02c4 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -517,6 +517,8 @@ title: Jukebox - local: model_doc/led title: LED + - local: model_doc/lfm2 + title: LFM2 - local: model_doc/llama title: LLaMA - local: model_doc/llama2 diff --git a/docs/source/en/model_doc/lfm2.md b/docs/source/en/model_doc/lfm2.md new file mode 100644 index 0000000000..c94e421d76 --- /dev/null +++ b/docs/source/en/model_doc/lfm2.md @@ -0,0 +1,84 @@ + + +
+PyTorch +
+ +# LFM2 + +## Overview + +[LFM2](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models) represents a new generation of Liquid Foundation Models developed by [Liquid AI](https://liquid.ai/), specifically designed for edge AI and on-device deployment. + +The models are available in three sizes (350M, 700M, and 1.2B parameters) and are engineered to run efficiently on CPU, GPU, and NPU hardware, making them particularly well-suited for applications requiring low latency, offline operation, and privacy. + +## Architecture + +The architecture consists of 16 blocks total: 10 double-gated short-range convolution blocks and 6 blocks of grouped query attention. This design stems from the concept of dynamical systems, where linear operations are modulated by input-dependent gates, allowing for "liquid" dynamics that can adapt in real-time. The short convolutions are particularly optimized for embedded SoC CPUs, making them ideal for devices that require fast, local inference without relying on cloud connectivity. + +The key architectural innovation of LFM2 lies in its systematic approach to balancing quality, latency, and memory efficiency through our STAR neural architecture search engine. Using STAR, Liquid AI optimized the models for real-world performance on embedded hardware, measuring actual peak memory usage and inference speed on Qualcomm Snapdragon processors. This results in models that achieve 2x faster decode and prefill performance compared to similar-sized models, while maintaining superior benchmark performance across knowledge, mathematics, instruction following, and multilingual tasks. + +## Example + +The following example shows how to generate an answer using the `AutoModelForCausalLM` class. + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +# Load model and tokenizer +model_id = "LiquidAI/LFM2-1.2B" +model = AutoModelForCausalLM.from_pretrained( + model_id, + device_map="auto", + torch_dtype="bfloat16", +) +tokenizer = AutoTokenizer.from_pretrained(model_id) + +# Generate answer +prompt = "What is C. elegans?" +input_ids = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + return_tensors="pt", + tokenize=True, +) + +output = model.generate( + input_ids, + do_sample=True, + temperature=0.3, + min_p=0.15, + repetition_penalty=1.05, + max_new_tokens=512, +) + +print(tokenizer.decode(output[0], skip_special_tokens=False)) +``` + +## Lfm2Config + +[[autodoc]] Lfm2Config + +## Lfm2Model + +[[autodoc]] Lfm2Model + - forward + +## Lfm2ForCausalLM + +[[autodoc]] Lfm2ForCausalLM + - forward \ No newline at end of file diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 6208945434..1daab346d9 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1989,6 +1989,7 @@ class GenerationMixin(ContinuousMixin): and "zamba" not in self.__class__.__name__.lower() and "bamba" not in self.__class__.__name__.lower() and "minimax" not in self.__class__.__name__.lower() + and "lfm2" not in self.__class__.__name__.lower() ) def _prepare_cache_for_generation( diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index b0913030e4..d2c59c0e8f 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -168,6 +168,7 @@ if TYPE_CHECKING: from .layoutxlm import * from .led import * from .levit import * + from .lfm2 import * from .lightglue import * from .lilt import * from .llama import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index d097206c12..9eaa3fc669 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -201,6 +201,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str]( ("layoutlmv3", "LayoutLMv3Config"), ("led", "LEDConfig"), ("levit", "LevitConfig"), + ("lfm2", "Lfm2Config"), ("lightglue", "LightGlueConfig"), ("lilt", "LiltConfig"), ("llama", "LlamaConfig"), @@ -591,6 +592,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str]( ("layoutxlm", "LayoutXLM"), ("led", "LED"), ("levit", "LeViT"), + ("lfm2", "Lfm2"), ("lightglue", "LightGlue"), ("lilt", "LiLT"), ("llama", "LLaMA"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index eb46b2d0e4..7b27324511 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -190,6 +190,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("layoutlmv3", "LayoutLMv3Model"), ("led", "LEDModel"), ("levit", "LevitModel"), + ("lfm2", "Lfm2Model"), ("lightglue", "LightGlueForKeypointMatching"), ("lilt", "LiltModel"), ("llama", "LlamaModel"), @@ -614,6 +615,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("helium", "HeliumForCausalLM"), ("jamba", "JambaForCausalLM"), ("jetmoe", "JetMoeForCausalLM"), + ("lfm2", "Lfm2ForCausalLM"), ("llama", "LlamaForCausalLM"), ("llama4", "Llama4ForCausalLM"), ("llama4_text", "Llama4ForCausalLM"), diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index eba5449a39..1210de68df 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -32,7 +32,7 @@ from torch import nn import transformers.models.jamba.modeling_jamba as modeling_jamba from transformers.activations import ACT2FN -from ...cache_utils import Cache # we need __iter__ and __len__ of pkv +from ...cache_utils import Cache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index c75644036b..20067056ba 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -32,10 +32,7 @@ from torch import nn from transformers.activations import ACT2FN -from ...cache_utils import ( - Cache, - DynamicCache, # we need __iter__ and __len__ of pkv -) +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index c635927c9e..ce2f13eb0e 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -28,7 +28,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache # we need __iter__ and __len__ of pkv +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available diff --git a/src/transformers/models/lfm2/__init__.py b/src/transformers/models/lfm2/__init__.py new file mode 100644 index 0000000000..239ab87983 --- /dev/null +++ b/src/transformers/models/lfm2/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_lfm2 import * + from .modeling_lfm2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/lfm2/configuration_lfm2.py b/src/transformers/models/lfm2/configuration_lfm2.py new file mode 100644 index 0000000000..ce331a311a --- /dev/null +++ b/src/transformers/models/lfm2/configuration_lfm2.py @@ -0,0 +1,165 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from ...configuration_utils import PretrainedConfig + + +class Lfm2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Lfm2Model`]. It is used to instantiate a LFM2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LFM2-1.2B model. + e.g. [LiquidAI/LFM2-1.2B](https://huggingface.co/LiquidAI/LFM2-1.2B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 65536): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Lfm2Model`] + hidden_size (`int`, *optional*, defaults to 2560): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 12288): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + max_position_embeddings (`int`, *optional*, defaults to 128000): + The maximum sequence length that this model might ever be used with. Lfm2 1 supports up to 2048 tokens, + Lfm2 2 up to 4096, CodeLfm2 up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + conv_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the conv layers. + conv_L_cache (`int`, *optional*, defaults to 3): + L_cache dim in the conv layers. + block_multiple_of (`int`, *optional*, defaults to 256): + Multiple for the `intermediate_size`. + block_ffn_dim_multiplier (`float`, *optional*, defaults to 1.0): + Multiplier for the `intermediate_size`. + block_auto_adjust_ff_dim (`bool`, *optional*, defaults to `True`): + Whether to adjust the dim of the `intermediate_size`. + full_attn_idxs (`Optional`, *optional*): + Index of the layers which use attention. + layer_types (`Optional`, *optional*): + Type of each layers. + + ```python + >>> from transformers import Lfm2Model, Lfm2Config + + >>> # Initializing a LFM2 model + >>> configuration = Lfm2Config() + + >>> # Initializing a model from the LFM2-1.2B style configuration + >>> model = Lfm2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "lfm2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size: int = 65536, + hidden_size: int = 2560, + intermediate_size: int = 12288, + num_hidden_layers: int = 32, + num_attention_heads: int = 32, + num_key_value_heads: int = 8, + max_position_embeddings: int = 128_000, + initializer_range: float = 0.02, + norm_eps: float = 0.00001, + use_cache: bool = True, + pad_token_id: int = 0, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = True, + rope_theta: float = 1000000.0, + conv_bias: bool = False, + conv_L_cache: int = 3, + block_multiple_of: int = 256, + block_ffn_dim_multiplier: float = 1.0, + block_auto_adjust_ff_dim: bool = True, + full_attn_idxs: Optional[list[int]] = None, + layer_types: Optional[list[str]] = None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.rope_theta = kwargs.get("theta", rope_theta) # to fit original config keys + self.max_position_embeddings = max_position_embeddings + self.use_cache = use_cache + self.norm_eps = norm_eps + self.initializer_range = initializer_range + + # attn operator config + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + + # custom operator config + self.conv_bias = conv_bias + self.conv_L_cache = conv_L_cache + + # MLP config + self.intermediate_size = kwargs.get("block_ff_dim", intermediate_size) # to fit original config keys + self.block_multiple_of = block_multiple_of + self.block_ffn_dim_multiplier = block_ffn_dim_multiplier + self.block_auto_adjust_ff_dim = block_auto_adjust_ff_dim + + self.layer_types = layer_types + if self.layer_types is None: + full_attn_idxs = full_attn_idxs if full_attn_idxs is not None else list(range(num_hidden_layers)) + self.layer_types = ["full_attention" if i in full_attn_idxs else "conv" for i in range(num_hidden_layers)] + + tie_word_embeddings = kwargs.get("tie_embedding", tie_word_embeddings) # to fit original config keys + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Lfm2Config"] diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py new file mode 100644 index 0000000000..049d5073f3 --- /dev/null +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -0,0 +1,763 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/lfm2/modular_lfm2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_lfm2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import check_model_inputs +from ...utils.import_utils import is_causal_conv1d_available +from .configuration_lfm2 import Lfm2Config + + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_fn, causal_conv1d_update = None, None + + +@use_kernel_forward_from_hub("RMSNorm") +class Lfm2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Lfm2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Lfm2RotaryEmbedding(nn.Module): + def __init__(self, config: Lfm2Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Lfm2MLP(nn.Module): + def __init__(self, config: Lfm2Config): + super().__init__() + intermediate_size = config.intermediate_size + if config.block_auto_adjust_ff_dim: + intermediate_size = int(2 * intermediate_size / 3) + # custom dim factor multiplier + if config.block_ffn_dim_multiplier is not None: + intermediate_size = int(config.block_ffn_dim_multiplier * intermediate_size) + intermediate_size = config.block_multiple_of * ( + (intermediate_size + config.block_multiple_of - 1) // config.block_multiple_of + ) + self.w1 = nn.Linear(config.hidden_size, intermediate_size, bias=False) + self.w3 = nn.Linear(config.hidden_size, intermediate_size, bias=False) + self.w2 = nn.Linear(intermediate_size, config.hidden_size, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class Lfm2HybridConvCache(DynamicCache): + """ + Attention and conv cache for Lfm2. + + It stores the Key and Value states as a list of tensors, one for each layer. + Attention layer cache shape: `[batch_size, num_heads, seq_len, head_dim]`. + Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`. + """ + + def __init__( + self, + config: Lfm2Config, + max_batch_size: int, + dtype: torch.dtype = torch.float32, + device: Union[torch.device, str, None] = None, + ): + super().__init__() # initialize key and value cache + self.max_batch_size = max_batch_size + self.layer_types = config.layer_types + self.first_attention_layer = self.layer_types.index("full_attention") + self.conv_L_cache = config.conv_L_cache + self._dtype = dtype + + self.conv_cache: list[torch.Tensor] = [] + device = torch.device(device) if device is not None else None + + for _ in range(config.num_hidden_layers): + conv_state = torch.zeros( + self.max_batch_size, + config.hidden_size, + self.conv_L_cache, + dtype=self._dtype, + device=device, + ) + torch._dynamo.mark_static_address(conv_state) + self.conv_cache.append(conv_state) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == self.first_attention_layer: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if key_states is not None: + if len(self.key_cache) <= layer_idx: + # There may be skipped layers, fill them with empty lists + for _ in range(len(self.key_cache), layer_idx): + self.key_cache.append(torch.tensor([])) + self.value_cache.append(torch.tensor([])) + self.key_cache.append(key_states) + self.value_cache.append(value_states) + elif ( + not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model + ): # fills previously skipped layers; checking for tensor causes errors + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_cache[layer_idx].device + self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.first_attention_layer if self.layer_types[layer_idx] != "full_attention" else layer_idx + if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: + raise NotImplementedError("Lfm2HybridConvCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("Lfm2HybridConvCache does not have a legacy cache equivalent.") + + def reset(self): + for layer_idx in range(len(self.conv_cache)): + # In-place ops prevent breaking the static address + self.conv_cache[layer_idx].zero_() + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Lfm2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Lfm2Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = True + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.q_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps) + self.k_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Lfm2HybridConvCache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_layernorm(self.q_proj(hidden_states).view(*hidden_shape)).transpose(1, 2) + key_states = self.k_layernorm(self.k_proj(hidden_states).view(*hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0, + scaling=self.scaling, + **kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + output = self.out_proj(attn_output) + return output, attn_weights + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +kernel_modules = (causal_conv1d_fn, causal_conv1d_update) +is_fast_path_available = all(kernel_modules) + + +class Lfm2ShortConv(nn.Module): + def __init__( + self, + config: Lfm2Config, + layer_idx: int, + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.L_cache = config.conv_L_cache + self.bias = config.conv_bias + + self.conv = nn.Conv1d( + in_channels=config.hidden_size, + out_channels=config.hidden_size, + kernel_size=self.L_cache, + groups=config.hidden_size, + bias=self.bias, + padding=self.L_cache - 1, + ) + self.in_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=self.bias) + self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=self.bias) + + def cuda_kernels_forward( + self, + x: torch.Tensor, + past_key_value: Optional[Lfm2HybridConvCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + x = apply_mask_to_padding_states(x, attention_mask) + BCx = self.in_proj(x).transpose(-1, -2) + B, C, x = BCx.chunk(3, dim=-2) + + Bx = B * x + + conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2)) + if past_key_value is not None and cache_position[0] > 0: + conv_out = causal_conv1d_update( + Bx.squeeze(-1), + past_key_value.conv_cache[self.layer_idx], + conv_weights, + self.conv.bias, + None, + ) + conv_out = conv_out.unsqueeze(-1) + else: + if past_key_value is not None: + conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) + past_key_value.conv_cache[self.layer_idx].copy_(conv_state) + + conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None) + + y = C * conv_out + y = self.out_proj(y.transpose(-1, -2).contiguous()) + return y + + def slow_forward( + self, + x: torch.Tensor, + past_key_value: Optional[Lfm2HybridConvCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + seqlen = x.shape[1] + + x = apply_mask_to_padding_states(x, attention_mask) + BCx = self.in_proj(x).transpose(-1, -2) + B, C, x = BCx.chunk(3, dim=-2) + + Bx = B * x + + if past_key_value is not None and cache_position[0] > 0: + conv_state = past_key_value.conv_cache[self.layer_idx] + cache_position = cache_position.clamp(0, self.L_cache - 1) + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = Bx.to(device=conv_state.device, dtype=conv_state.dtype) + past_key_value.conv_cache[self.layer_idx].copy_(conv_state) + conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1) + if self.bias: + conv_out += self.conv.bias + + conv_out = conv_out.unsqueeze(-1) + else: + if past_key_value is not None: + conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) + past_key_value.conv_cache[self.layer_idx].copy_(conv_state) + + conv_out = self.conv(Bx)[..., :seqlen] + + y = C * conv_out + y = y.transpose(-1, -2).contiguous() + y = self.out_proj(y) + return y + + def forward( + self, + hidden_states: torch.Tensor, + past_key_value: Optional[Lfm2HybridConvCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in hidden_states.device.type and not torch._dynamo.is_compiling(): + return self.cuda_kernels_forward(hidden_states, past_key_value, cache_position, attention_mask) + return self.slow_forward(hidden_states, past_key_value, cache_position, attention_mask) + + +class Lfm2DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Lfm2Config, layer_idx: int): + super().__init__() + self.is_attention_layer = config.layer_types[layer_idx] == "full_attention" + + if self.is_attention_layer: + self.self_attn = Lfm2Attention(config, layer_idx) + else: + self.conv = Lfm2ShortConv(config, layer_idx) + self.feed_forward = Lfm2MLP(config) + self.operator_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[tuple[torch.Tensor]] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> torch.Tensor: + residual = hidden_states + if self.is_attention_layer: + hidden_states, _ = self.self_attn( + hidden_states=self.operator_norm(hidden_states), + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + cache_position=cache_position, + **kwargs, + ) + else: + hidden_states = self.conv( + hidden_states=self.operator_norm(hidden_states), + past_key_value=past_key_value, + cache_position=cache_position, + attention_mask=attention_mask, + ) + hidden_states = hidden_states + residual + hidden_states = hidden_states + self.feed_forward(self.ffn_norm(hidden_states)) + + return hidden_states + + +@auto_docstring +class Lfm2PreTrainedModel(PreTrainedModel): + config_class = Lfm2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Lfm2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_flash_attn_3 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = False + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Lfm2DecoderLayer, + "attentions": Lfm2Attention, + } + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Lfm2RMSNorm): + module.weight.data.fill_(1.0) + + +@auto_docstring +class Lfm2Model(Lfm2PreTrainedModel): + def __init__(self, config: Lfm2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Lfm2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.rotary_emb = Lfm2RotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.pos_emb = Lfm2RotaryEmbedding(config) + self.embedding_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Lfm2HybridConvCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + batch_size = inputs_embeds.shape[0] + past_key_values = Lfm2HybridConvCache( + config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.pos_emb(hidden_states, position_ids) + + # decoder layers + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.embedding_norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class Lfm2ForCausalLM(Lfm2PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Lfm2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Lfm2ForCausalLM + + >>> model = Lfm2ForCausalLM.from_pretrained("meta-lfm2/Lfm2-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-lfm2/Lfm2-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["Lfm2ForCausalLM", "Lfm2Model", "Lfm2PreTrainedModel"] diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py new file mode 100644 index 0000000000..e0d617daf6 --- /dev/null +++ b/src/transformers/models/lfm2/modular_lfm2.py @@ -0,0 +1,500 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...cache_utils import DynamicCache +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, logging +from ...utils.import_utils import is_causal_conv1d_available +from ..bamba.modeling_bamba import apply_mask_to_padding_states +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaForCausalLM, + LlamaModel, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, +) +from .configuration_lfm2 import Lfm2Config + + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_fn, causal_conv1d_update = None, None + + +kernel_modules = (causal_conv1d_fn, causal_conv1d_update) +is_fast_path_available = all(kernel_modules) + + +logger = logging.get_logger(__name__) + + +class Lfm2RMSNorm(LlamaRMSNorm): + pass + + +class Lfm2RotaryEmbedding(LlamaRotaryEmbedding): + pass + + +class Lfm2MLP(nn.Module): + def __init__(self, config: Lfm2Config): + super().__init__() + intermediate_size = config.intermediate_size + if config.block_auto_adjust_ff_dim: + intermediate_size = int(2 * intermediate_size / 3) + # custom dim factor multiplier + if config.block_ffn_dim_multiplier is not None: + intermediate_size = int(config.block_ffn_dim_multiplier * intermediate_size) + intermediate_size = config.block_multiple_of * ( + (intermediate_size + config.block_multiple_of - 1) // config.block_multiple_of + ) + self.w1 = nn.Linear(config.hidden_size, intermediate_size, bias=False) + self.w3 = nn.Linear(config.hidden_size, intermediate_size, bias=False) + self.w2 = nn.Linear(intermediate_size, config.hidden_size, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class Lfm2HybridConvCache(DynamicCache): + """ + Attention and conv cache for Lfm2. + + It stores the Key and Value states as a list of tensors, one for each layer. + Attention layer cache shape: `[batch_size, num_heads, seq_len, head_dim]`. + Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`. + """ + + def __init__( + self, + config: Lfm2Config, + max_batch_size: int, + dtype: torch.dtype = torch.float32, + device: Union[torch.device, str, None] = None, + ): + super().__init__() # initialize key and value cache + self.max_batch_size = max_batch_size + self.layer_types = config.layer_types + self.first_attention_layer = self.layer_types.index("full_attention") + self.conv_L_cache = config.conv_L_cache + self._dtype = dtype + + self.conv_cache: list[torch.Tensor] = [] + device = torch.device(device) if device is not None else None + + for _ in range(config.num_hidden_layers): + conv_state = torch.zeros( + self.max_batch_size, + config.hidden_size, + self.conv_L_cache, + dtype=self._dtype, + device=device, + ) + torch._dynamo.mark_static_address(conv_state) + self.conv_cache.append(conv_state) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == self.first_attention_layer: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if key_states is not None: + if len(self.key_cache) <= layer_idx: + # There may be skipped layers, fill them with empty lists + for _ in range(len(self.key_cache), layer_idx): + self.key_cache.append(torch.tensor([])) + self.value_cache.append(torch.tensor([])) + self.key_cache.append(key_states) + self.value_cache.append(value_states) + elif ( + not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model + ): # fills previously skipped layers; checking for tensor causes errors + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_cache[layer_idx].device + self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.first_attention_layer if self.layer_types[layer_idx] != "full_attention" else layer_idx + if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: + raise NotImplementedError("Lfm2HybridConvCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("Lfm2HybridConvCache does not have a legacy cache equivalent.") + + def reset(self): + for layer_idx in range(len(self.conv_cache)): + # In-place ops prevent breaking the static address + self.conv_cache[layer_idx].zero_() + + +class Lfm2Attention(LlamaAttention): + def __init__(self, config: Lfm2Config, layer_idx: int): + super().__init__(config, layer_idx) + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.q_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps) + self.k_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps) + del self.o_proj + del self.attention_dropout + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Lfm2HybridConvCache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_layernorm(self.q_proj(hidden_states).view(*hidden_shape)).transpose(1, 2) + key_states = self.k_layernorm(self.k_proj(hidden_states).view(*hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0, + scaling=self.scaling, + **kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + output = self.out_proj(attn_output) + return output, attn_weights + + +class Lfm2ShortConv(nn.Module): + def __init__( + self, + config: Lfm2Config, + layer_idx: int, + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.L_cache = config.conv_L_cache + self.bias = config.conv_bias + + self.conv = nn.Conv1d( + in_channels=config.hidden_size, + out_channels=config.hidden_size, + kernel_size=self.L_cache, + groups=config.hidden_size, + bias=self.bias, + padding=self.L_cache - 1, + ) + self.in_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=self.bias) + self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=self.bias) + + def cuda_kernels_forward( + self, + x: torch.Tensor, + past_key_value: Optional[Lfm2HybridConvCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + x = apply_mask_to_padding_states(x, attention_mask) + BCx = self.in_proj(x).transpose(-1, -2) + B, C, x = BCx.chunk(3, dim=-2) + + Bx = B * x + + conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2)) + if past_key_value is not None and cache_position[0] > 0: + conv_out = causal_conv1d_update( + Bx.squeeze(-1), + past_key_value.conv_cache[self.layer_idx], + conv_weights, + self.conv.bias, + None, + ) + conv_out = conv_out.unsqueeze(-1) + else: + if past_key_value is not None: + conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) + past_key_value.conv_cache[self.layer_idx].copy_(conv_state) + + conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None) + + y = C * conv_out + y = self.out_proj(y.transpose(-1, -2).contiguous()) + return y + + def slow_forward( + self, + x: torch.Tensor, + past_key_value: Optional[Lfm2HybridConvCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + seqlen = x.shape[1] + + x = apply_mask_to_padding_states(x, attention_mask) + BCx = self.in_proj(x).transpose(-1, -2) + B, C, x = BCx.chunk(3, dim=-2) + + Bx = B * x + + if past_key_value is not None and cache_position[0] > 0: + conv_state = past_key_value.conv_cache[self.layer_idx] + cache_position = cache_position.clamp(0, self.L_cache - 1) + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = Bx.to(device=conv_state.device, dtype=conv_state.dtype) + past_key_value.conv_cache[self.layer_idx].copy_(conv_state) + conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1) + if self.bias: + conv_out += self.conv.bias + + conv_out = conv_out.unsqueeze(-1) + else: + if past_key_value is not None: + conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) + past_key_value.conv_cache[self.layer_idx].copy_(conv_state) + + conv_out = self.conv(Bx)[..., :seqlen] + + y = C * conv_out + y = y.transpose(-1, -2).contiguous() + y = self.out_proj(y) + return y + + def forward( + self, + hidden_states: torch.Tensor, + past_key_value: Optional[Lfm2HybridConvCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in hidden_states.device.type and not torch._dynamo.is_compiling(): + return self.cuda_kernels_forward(hidden_states, past_key_value, cache_position, attention_mask) + return self.slow_forward(hidden_states, past_key_value, cache_position, attention_mask) + + +class Lfm2DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Lfm2Config, layer_idx: int): + super().__init__() + self.is_attention_layer = config.layer_types[layer_idx] == "full_attention" + + if self.is_attention_layer: + self.self_attn = Lfm2Attention(config, layer_idx) + else: + self.conv = Lfm2ShortConv(config, layer_idx) + self.feed_forward = Lfm2MLP(config) + self.operator_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[tuple[torch.Tensor]] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> torch.Tensor: + residual = hidden_states + if self.is_attention_layer: + hidden_states, _ = self.self_attn( + hidden_states=self.operator_norm(hidden_states), + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + cache_position=cache_position, + **kwargs, + ) + else: + hidden_states = self.conv( + hidden_states=self.operator_norm(hidden_states), + past_key_value=past_key_value, + cache_position=cache_position, + attention_mask=attention_mask, + ) + hidden_states = hidden_states + residual + hidden_states = hidden_states + self.feed_forward(self.ffn_norm(hidden_states)) + + return hidden_states + + +class Lfm2PreTrainedModel(LlamaPreTrainedModel): + _supports_static_cache = False + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Lfm2RMSNorm): + module.weight.data.fill_(1.0) + + +class Lfm2Model(LlamaModel): + def __init__(self, config: Lfm2Config): + super().__init__(config) + self.pos_emb = Lfm2RotaryEmbedding(config) + self.embedding_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps) + del self.norm + del self.rotary_emv + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Lfm2HybridConvCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + batch_size = inputs_embeds.shape[0] + past_key_values = Lfm2HybridConvCache( + config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.pos_emb(hidden_states, position_ids) + + # decoder layers + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.embedding_norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +class Lfm2ForCausalLM(LlamaForCausalLM): + pass + + +__all__ = ["Lfm2ForCausalLM", "Lfm2Model", "Lfm2PreTrainedModel"] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index e01b1ac50b..6a82319cf9 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2503,6 +2503,7 @@ class GenerationTesterMixin: "xlnet", "zamba", "zamba2", + "lfm2", ) has_standard_cache = not any( model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache diff --git a/tests/models/lfm2/__init__.py b/tests/models/lfm2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/lfm2/test_modeling_lfm2.py b/tests/models/lfm2/test_modeling_lfm2.py new file mode 100644 index 0000000000..173e5e101d --- /dev/null +++ b/tests/models/lfm2/test_modeling_lfm2.py @@ -0,0 +1,93 @@ +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch LLaMA model.""" + +import unittest + +from transformers import is_torch_available +from transformers.testing_utils import ( + require_read_token, + require_torch, + require_torch_accelerator, + slow, +) + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester + + +if is_torch_available(): + from transformers import Lfm2Config, Lfm2ForCausalLM, Lfm2Model + + +class Lfm2ModelTester(CausalLMModelTester): + if is_torch_available(): + config_class = Lfm2Config + base_model_class = Lfm2Model + causal_lm_class = Lfm2ForCausalLM + + def __init__( + self, + parent, + layer_types=["full_attention", "conv"], + ): + super().__init__(parent) + self.layer_types = layer_types + + +@require_torch +class Lfm2ModelTest(CausalLMModelTest, unittest.TestCase): + all_model_classes = (Lfm2Model, Lfm2ForCausalLM) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": Lfm2Model, + "text-generation": Lfm2ForCausalLM, + } + if is_torch_available() + else {} + ) + test_headmasking = False + test_pruning = False + fx_compatible = False + model_tester_class = Lfm2ModelTester + # used in `test_torch_compile_for_training` + _torch_compile_train_cls = Lfm2ForCausalLM if is_torch_available() else None + + @unittest.skip( + "Lfm2 alternates between attention and conv layers, so attention are only returned for attention layers" + ) + def test_attention_outputs(self): + pass + + @unittest.skip("Lfm2 has a special cache format as it alternates between attention and conv layers") + def test_past_key_values_format(self): + pass + + @unittest.skip("Lfm2 has a special cache format which is not compatible with contrastive search") + def test_contrastive_generate(self): + pass + + @unittest.skip("Lfm2 has a special cache format which is not compatible with contrastive search") + def test_contrastive_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip("Lfm2 has a special cache format which is not compatible with contrastive search") + def test_contrastive_generate_low_memory(self): + pass + + +@require_torch_accelerator +@require_read_token +@slow +class Lfm2IntegrationTest(unittest.TestCase): + pass diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 8058558b40..3795270baf 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -32,6 +32,7 @@ transformers = direct_transformers_import(PATH_TO_TRANSFORMERS) CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING SPECIAL_CASES_TO_ALLOW = { + "Lfm2Config": ["full_attn_idxs", "tie_word_embeddings"], # used internally during generation to provide the custom logit processors with their necessary information "DiaConfig": [ "delay_pattern",