* [modeling][lfm2] LFM2 model on 4.53.0 interface

* [configuration] hook in LFM2 keys

* [modeling][lfm2] update modeling interface for 4.53.1

* [modeling][lfm2] apply mask to hidden conv states

* [misc] ruff format/lint

* [modeling][lfm2] minor: NotImplemented legacy cache conversion

* Create lfm2.md

* create nice modular

* style

* Update modeling_auto.py

* clean and start adding tests

* style

* Update test_modeling_lfm2.py

* Update __init__.py

* small test model size

* config

* small fix

* fix

* remove useless config attrs -> block_dim and conv_dim are hiden_size

* fix prepare inputs

* fix config

* test

* typo

* skip tests accordingly

* config docstrings

* add doc to .md

* skip config docstring check

---------

Co-authored-by: Maxime Labonne <81252890+mlabonne@users.noreply.github.com>
Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
Paul Pak
2025-07-10 08:07:33 -06:00
committed by GitHub
parent 38c3931362
commit 9682d07f92
17 changed files with 1645 additions and 6 deletions

View File

@@ -517,6 +517,8 @@
title: Jukebox title: Jukebox
- local: model_doc/led - local: model_doc/led
title: LED title: LED
- local: model_doc/lfm2
title: LFM2
- local: model_doc/llama - local: model_doc/llama
title: LLaMA title: LLaMA
- local: model_doc/llama2 - local: model_doc/llama2

View File

@@ -0,0 +1,84 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
# LFM2
## Overview
[LFM2](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models) represents a new generation of Liquid Foundation Models developed by [Liquid AI](https://liquid.ai/), specifically designed for edge AI and on-device deployment.
The models are available in three sizes (350M, 700M, and 1.2B parameters) and are engineered to run efficiently on CPU, GPU, and NPU hardware, making them particularly well-suited for applications requiring low latency, offline operation, and privacy.
## Architecture
The architecture consists of 16 blocks total: 10 double-gated short-range convolution blocks and 6 blocks of grouped query attention. This design stems from the concept of dynamical systems, where linear operations are modulated by input-dependent gates, allowing for "liquid" dynamics that can adapt in real-time. The short convolutions are particularly optimized for embedded SoC CPUs, making them ideal for devices that require fast, local inference without relying on cloud connectivity.
The key architectural innovation of LFM2 lies in its systematic approach to balancing quality, latency, and memory efficiency through our STAR neural architecture search engine. Using STAR, Liquid AI optimized the models for real-world performance on embedded hardware, measuring actual peak memory usage and inference speed on Qualcomm Snapdragon processors. This results in models that achieve 2x faster decode and prefill performance compared to similar-sized models, while maintaining superior benchmark performance across knowledge, mathematics, instruction following, and multilingual tasks.
## Example
The following example shows how to generate an answer using the `AutoModelForCausalLM` class.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load model and tokenizer
model_id = "LiquidAI/LFM2-1.2B"
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
torch_dtype="bfloat16",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Generate answer
prompt = "What is C. elegans?"
input_ids = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
return_tensors="pt",
tokenize=True,
)
output = model.generate(
input_ids,
do_sample=True,
temperature=0.3,
min_p=0.15,
repetition_penalty=1.05,
max_new_tokens=512,
)
print(tokenizer.decode(output[0], skip_special_tokens=False))
```
## Lfm2Config
[[autodoc]] Lfm2Config
## Lfm2Model
[[autodoc]] Lfm2Model
- forward
## Lfm2ForCausalLM
[[autodoc]] Lfm2ForCausalLM
- forward

View File

@@ -1989,6 +1989,7 @@ class GenerationMixin(ContinuousMixin):
and "zamba" not in self.__class__.__name__.lower() and "zamba" not in self.__class__.__name__.lower()
and "bamba" not in self.__class__.__name__.lower() and "bamba" not in self.__class__.__name__.lower()
and "minimax" not in self.__class__.__name__.lower() and "minimax" not in self.__class__.__name__.lower()
and "lfm2" not in self.__class__.__name__.lower()
) )
def _prepare_cache_for_generation( def _prepare_cache_for_generation(

View File

@@ -168,6 +168,7 @@ if TYPE_CHECKING:
from .layoutxlm import * from .layoutxlm import *
from .led import * from .led import *
from .levit import * from .levit import *
from .lfm2 import *
from .lightglue import * from .lightglue import *
from .lilt import * from .lilt import *
from .llama import * from .llama import *

View File

@@ -201,6 +201,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("layoutlmv3", "LayoutLMv3Config"), ("layoutlmv3", "LayoutLMv3Config"),
("led", "LEDConfig"), ("led", "LEDConfig"),
("levit", "LevitConfig"), ("levit", "LevitConfig"),
("lfm2", "Lfm2Config"),
("lightglue", "LightGlueConfig"), ("lightglue", "LightGlueConfig"),
("lilt", "LiltConfig"), ("lilt", "LiltConfig"),
("llama", "LlamaConfig"), ("llama", "LlamaConfig"),
@@ -591,6 +592,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("layoutxlm", "LayoutXLM"), ("layoutxlm", "LayoutXLM"),
("led", "LED"), ("led", "LED"),
("levit", "LeViT"), ("levit", "LeViT"),
("lfm2", "Lfm2"),
("lightglue", "LightGlue"), ("lightglue", "LightGlue"),
("lilt", "LiLT"), ("lilt", "LiLT"),
("llama", "LLaMA"), ("llama", "LLaMA"),

View File

@@ -190,6 +190,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("layoutlmv3", "LayoutLMv3Model"), ("layoutlmv3", "LayoutLMv3Model"),
("led", "LEDModel"), ("led", "LEDModel"),
("levit", "LevitModel"), ("levit", "LevitModel"),
("lfm2", "Lfm2Model"),
("lightglue", "LightGlueForKeypointMatching"), ("lightglue", "LightGlueForKeypointMatching"),
("lilt", "LiltModel"), ("lilt", "LiltModel"),
("llama", "LlamaModel"), ("llama", "LlamaModel"),
@@ -614,6 +615,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("helium", "HeliumForCausalLM"), ("helium", "HeliumForCausalLM"),
("jamba", "JambaForCausalLM"), ("jamba", "JambaForCausalLM"),
("jetmoe", "JetMoeForCausalLM"), ("jetmoe", "JetMoeForCausalLM"),
("lfm2", "Lfm2ForCausalLM"),
("llama", "LlamaForCausalLM"), ("llama", "LlamaForCausalLM"),
("llama4", "Llama4ForCausalLM"), ("llama4", "Llama4ForCausalLM"),
("llama4_text", "Llama4ForCausalLM"), ("llama4_text", "Llama4ForCausalLM"),

View File

@@ -32,7 +32,7 @@ from torch import nn
import transformers.models.jamba.modeling_jamba as modeling_jamba import transformers.models.jamba.modeling_jamba as modeling_jamba
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from ...cache_utils import Cache # we need __iter__ and __len__ of pkv from ...cache_utils import Cache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter

View File

@@ -32,10 +32,7 @@ from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from ...cache_utils import ( from ...cache_utils import Cache, DynamicCache
Cache,
DynamicCache, # we need __iter__ and __len__ of pkv
)
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter

View File

@@ -28,7 +28,7 @@ import torch.utils.checkpoint
from torch import nn from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache # we need __iter__ and __len__ of pkv from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available

View File

@@ -0,0 +1,27 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_lfm2 import *
from .modeling_lfm2 import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@@ -0,0 +1,165 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from ...configuration_utils import PretrainedConfig
class Lfm2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Lfm2Model`]. It is used to instantiate a LFM2
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the LFM2-1.2B model.
e.g. [LiquidAI/LFM2-1.2B](https://huggingface.co/LiquidAI/LFM2-1.2B)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 65536):
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Lfm2Model`]
hidden_size (`int`, *optional*, defaults to 2560):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 12288):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
`num_attention_heads`.
max_position_embeddings (`int`, *optional*, defaults to 128000):
The maximum sequence length that this model might ever be used with. Lfm2 1 supports up to 2048 tokens,
Lfm2 2 up to 4096, CodeLfm2 up to 16384.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 1000000.0):
The base period of the RoPE embeddings.
conv_bias (`bool`, *optional*, defaults to `False`):
Whether to use bias in the conv layers.
conv_L_cache (`int`, *optional*, defaults to 3):
L_cache dim in the conv layers.
block_multiple_of (`int`, *optional*, defaults to 256):
Multiple for the `intermediate_size`.
block_ffn_dim_multiplier (`float`, *optional*, defaults to 1.0):
Multiplier for the `intermediate_size`.
block_auto_adjust_ff_dim (`bool`, *optional*, defaults to `True`):
Whether to adjust the dim of the `intermediate_size`.
full_attn_idxs (`Optional`, *optional*):
Index of the layers which use attention.
layer_types (`Optional`, *optional*):
Type of each layers.
```python
>>> from transformers import Lfm2Model, Lfm2Config
>>> # Initializing a LFM2 model
>>> configuration = Lfm2Config()
>>> # Initializing a model from the LFM2-1.2B style configuration
>>> model = Lfm2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "lfm2"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size: int = 65536,
hidden_size: int = 2560,
intermediate_size: int = 12288,
num_hidden_layers: int = 32,
num_attention_heads: int = 32,
num_key_value_heads: int = 8,
max_position_embeddings: int = 128_000,
initializer_range: float = 0.02,
norm_eps: float = 0.00001,
use_cache: bool = True,
pad_token_id: int = 0,
bos_token_id: int = 1,
eos_token_id: int = 2,
tie_word_embeddings: bool = True,
rope_theta: float = 1000000.0,
conv_bias: bool = False,
conv_L_cache: int = 3,
block_multiple_of: int = 256,
block_ffn_dim_multiplier: float = 1.0,
block_auto_adjust_ff_dim: bool = True,
full_attn_idxs: Optional[list[int]] = None,
layer_types: Optional[list[str]] = None,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.rope_theta = kwargs.get("theta", rope_theta) # to fit original config keys
self.max_position_embeddings = max_position_embeddings
self.use_cache = use_cache
self.norm_eps = norm_eps
self.initializer_range = initializer_range
# attn operator config
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
# custom operator config
self.conv_bias = conv_bias
self.conv_L_cache = conv_L_cache
# MLP config
self.intermediate_size = kwargs.get("block_ff_dim", intermediate_size) # to fit original config keys
self.block_multiple_of = block_multiple_of
self.block_ffn_dim_multiplier = block_ffn_dim_multiplier
self.block_auto_adjust_ff_dim = block_auto_adjust_ff_dim
self.layer_types = layer_types
if self.layer_types is None:
full_attn_idxs = full_attn_idxs if full_attn_idxs is not None else list(range(num_hidden_layers))
self.layer_types = ["full_attention" if i in full_attn_idxs else "conv" for i in range(num_hidden_layers)]
tie_word_embeddings = kwargs.get("tie_embedding", tie_word_embeddings) # to fit original config keys
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
__all__ = ["Lfm2Config"]

View File

@@ -0,0 +1,763 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/lfm2/modular_lfm2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_lfm2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import check_model_inputs
from ...utils.import_utils import is_causal_conv1d_available
from .configuration_lfm2 import Lfm2Config
if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_fn, causal_conv1d_update = None, None
@use_kernel_forward_from_hub("RMSNorm")
class Lfm2RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Lfm2RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Lfm2RotaryEmbedding(nn.Module):
def __init__(self, config: Lfm2Config, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Lfm2MLP(nn.Module):
def __init__(self, config: Lfm2Config):
super().__init__()
intermediate_size = config.intermediate_size
if config.block_auto_adjust_ff_dim:
intermediate_size = int(2 * intermediate_size / 3)
# custom dim factor multiplier
if config.block_ffn_dim_multiplier is not None:
intermediate_size = int(config.block_ffn_dim_multiplier * intermediate_size)
intermediate_size = config.block_multiple_of * (
(intermediate_size + config.block_multiple_of - 1) // config.block_multiple_of
)
self.w1 = nn.Linear(config.hidden_size, intermediate_size, bias=False)
self.w3 = nn.Linear(config.hidden_size, intermediate_size, bias=False)
self.w2 = nn.Linear(intermediate_size, config.hidden_size, bias=False)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class Lfm2HybridConvCache(DynamicCache):
"""
Attention and conv cache for Lfm2.
It stores the Key and Value states as a list of tensors, one for each layer.
Attention layer cache shape: `[batch_size, num_heads, seq_len, head_dim]`.
Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`.
"""
def __init__(
self,
config: Lfm2Config,
max_batch_size: int,
dtype: torch.dtype = torch.float32,
device: Union[torch.device, str, None] = None,
):
super().__init__() # initialize key and value cache
self.max_batch_size = max_batch_size
self.layer_types = config.layer_types
self.first_attention_layer = self.layer_types.index("full_attention")
self.conv_L_cache = config.conv_L_cache
self._dtype = dtype
self.conv_cache: list[torch.Tensor] = []
device = torch.device(device) if device is not None else None
for _ in range(config.num_hidden_layers):
conv_state = torch.zeros(
self.max_batch_size,
config.hidden_size,
self.conv_L_cache,
dtype=self._dtype,
device=device,
)
torch._dynamo.mark_static_address(conv_state)
self.conv_cache.append(conv_state)
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
Return:
A tuple containing the updated key and value states.
"""
# Update the number of seen tokens
if layer_idx == self.first_attention_layer:
self._seen_tokens += key_states.shape[-2]
# Update the cache
if key_states is not None:
if len(self.key_cache) <= layer_idx:
# There may be skipped layers, fill them with empty lists
for _ in range(len(self.key_cache), layer_idx):
self.key_cache.append(torch.tensor([]))
self.value_cache.append(torch.tensor([]))
self.key_cache.append(key_states)
self.value_cache.append(value_states)
elif (
not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model
): # fills previously skipped layers; checking for tensor causes errors
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.conv_cache[layer_idx].device
self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device))
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# take any layer that contains cache and not empty tensor
layer_idx = self.first_attention_layer if self.layer_types[layer_idx] != "full_attention" else layer_idx
if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0:
return 0
return self.key_cache[layer_idx].shape[-2]
def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]:
raise NotImplementedError("Lfm2HybridConvCache does not have a legacy cache equivalent.")
@classmethod
def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
raise NotImplementedError("Lfm2HybridConvCache does not have a legacy cache equivalent.")
def reset(self):
for layer_idx in range(len(self.conv_cache)):
# In-place ops prevent breaking the static address
self.conv_cache[layer_idx].zero_()
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs: Unpack[TransformersKwargs],
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class Lfm2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Lfm2Config, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.is_causal = True
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
self.q_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)
self.k_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Lfm2HybridConvCache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_layernorm(self.q_proj(hidden_states).view(*hidden_shape)).transpose(1, 2)
key_states = self.k_layernorm(self.k_proj(hidden_states).view(*hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
output = self.out_proj(attn_output)
return output, attn_weights
def apply_mask_to_padding_states(hidden_states, attention_mask):
"""
Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
"""
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
dtype = hidden_states.dtype
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
return hidden_states
kernel_modules = (causal_conv1d_fn, causal_conv1d_update)
is_fast_path_available = all(kernel_modules)
class Lfm2ShortConv(nn.Module):
def __init__(
self,
config: Lfm2Config,
layer_idx: int,
):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.L_cache = config.conv_L_cache
self.bias = config.conv_bias
self.conv = nn.Conv1d(
in_channels=config.hidden_size,
out_channels=config.hidden_size,
kernel_size=self.L_cache,
groups=config.hidden_size,
bias=self.bias,
padding=self.L_cache - 1,
)
self.in_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=self.bias)
self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=self.bias)
def cuda_kernels_forward(
self,
x: torch.Tensor,
past_key_value: Optional[Lfm2HybridConvCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
x = apply_mask_to_padding_states(x, attention_mask)
BCx = self.in_proj(x).transpose(-1, -2)
B, C, x = BCx.chunk(3, dim=-2)
Bx = B * x
conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2))
if past_key_value is not None and cache_position[0] > 0:
conv_out = causal_conv1d_update(
Bx.squeeze(-1),
past_key_value.conv_cache[self.layer_idx],
conv_weights,
self.conv.bias,
None,
)
conv_out = conv_out.unsqueeze(-1)
else:
if past_key_value is not None:
conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))
past_key_value.conv_cache[self.layer_idx].copy_(conv_state)
conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None)
y = C * conv_out
y = self.out_proj(y.transpose(-1, -2).contiguous())
return y
def slow_forward(
self,
x: torch.Tensor,
past_key_value: Optional[Lfm2HybridConvCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
seqlen = x.shape[1]
x = apply_mask_to_padding_states(x, attention_mask)
BCx = self.in_proj(x).transpose(-1, -2)
B, C, x = BCx.chunk(3, dim=-2)
Bx = B * x
if past_key_value is not None and cache_position[0] > 0:
conv_state = past_key_value.conv_cache[self.layer_idx]
cache_position = cache_position.clamp(0, self.L_cache - 1)
conv_state = conv_state.roll(shifts=-1, dims=-1)
conv_state[:, :, cache_position] = Bx.to(device=conv_state.device, dtype=conv_state.dtype)
past_key_value.conv_cache[self.layer_idx].copy_(conv_state)
conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1)
if self.bias:
conv_out += self.conv.bias
conv_out = conv_out.unsqueeze(-1)
else:
if past_key_value is not None:
conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))
past_key_value.conv_cache[self.layer_idx].copy_(conv_state)
conv_out = self.conv(Bx)[..., :seqlen]
y = C * conv_out
y = y.transpose(-1, -2).contiguous()
y = self.out_proj(y)
return y
def forward(
self,
hidden_states: torch.Tensor,
past_key_value: Optional[Lfm2HybridConvCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
if is_fast_path_available and "cuda" in hidden_states.device.type and not torch._dynamo.is_compiling():
return self.cuda_kernels_forward(hidden_states, past_key_value, cache_position, attention_mask)
return self.slow_forward(hidden_states, past_key_value, cache_position, attention_mask)
class Lfm2DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Lfm2Config, layer_idx: int):
super().__init__()
self.is_attention_layer = config.layer_types[layer_idx] == "full_attention"
if self.is_attention_layer:
self.self_attn = Lfm2Attention(config, layer_idx)
else:
self.conv = Lfm2ShortConv(config, layer_idx)
self.feed_forward = Lfm2MLP(config)
self.operator_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
self.ffn_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[tuple[torch.Tensor]] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
if self.is_attention_layer:
hidden_states, _ = self.self_attn(
hidden_states=self.operator_norm(hidden_states),
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
cache_position=cache_position,
**kwargs,
)
else:
hidden_states = self.conv(
hidden_states=self.operator_norm(hidden_states),
past_key_value=past_key_value,
cache_position=cache_position,
attention_mask=attention_mask,
)
hidden_states = hidden_states + residual
hidden_states = hidden_states + self.feed_forward(self.ffn_norm(hidden_states))
return hidden_states
@auto_docstring
class Lfm2PreTrainedModel(PreTrainedModel):
config_class = Lfm2Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Lfm2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_flash_attn_3 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = False
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Lfm2DecoderLayer,
"attentions": Lfm2Attention,
}
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv1d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Lfm2RMSNorm):
module.weight.data.fill_(1.0)
@auto_docstring
class Lfm2Model(Lfm2PreTrainedModel):
def __init__(self, config: Lfm2Config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[Lfm2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.rotary_emb = Lfm2RotaryEmbedding(config=config)
self.gradient_checkpointing = False
self.pos_emb = Lfm2RotaryEmbedding(config)
self.embedding_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@check_model_inputs
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Lfm2HybridConvCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
batch_size = inputs_embeds.shape[0]
past_key_values = Lfm2HybridConvCache(
config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device
)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = self.pos_emb(hidden_states, position_ids)
# decoder layers
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = self.embedding_norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
)
@auto_docstring
class Lfm2ForCausalLM(Lfm2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config):
super().__init__(config)
self.model = Lfm2Model(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python
>>> from transformers import AutoTokenizer, Lfm2ForCausalLM
>>> model = Lfm2ForCausalLM.from_pretrained("meta-lfm2/Lfm2-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-lfm2/Lfm2-2-7b-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
__all__ = ["Lfm2ForCausalLM", "Lfm2Model", "Lfm2PreTrainedModel"]

View File

@@ -0,0 +1,500 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from ...cache_utils import DynamicCache
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, logging
from ...utils.import_utils import is_causal_conv1d_available
from ..bamba.modeling_bamba import apply_mask_to_padding_states
from ..llama.modeling_llama import (
LlamaAttention,
LlamaForCausalLM,
LlamaModel,
LlamaPreTrainedModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
eager_attention_forward,
)
from .configuration_lfm2 import Lfm2Config
if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_fn, causal_conv1d_update = None, None
kernel_modules = (causal_conv1d_fn, causal_conv1d_update)
is_fast_path_available = all(kernel_modules)
logger = logging.get_logger(__name__)
class Lfm2RMSNorm(LlamaRMSNorm):
pass
class Lfm2RotaryEmbedding(LlamaRotaryEmbedding):
pass
class Lfm2MLP(nn.Module):
def __init__(self, config: Lfm2Config):
super().__init__()
intermediate_size = config.intermediate_size
if config.block_auto_adjust_ff_dim:
intermediate_size = int(2 * intermediate_size / 3)
# custom dim factor multiplier
if config.block_ffn_dim_multiplier is not None:
intermediate_size = int(config.block_ffn_dim_multiplier * intermediate_size)
intermediate_size = config.block_multiple_of * (
(intermediate_size + config.block_multiple_of - 1) // config.block_multiple_of
)
self.w1 = nn.Linear(config.hidden_size, intermediate_size, bias=False)
self.w3 = nn.Linear(config.hidden_size, intermediate_size, bias=False)
self.w2 = nn.Linear(intermediate_size, config.hidden_size, bias=False)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class Lfm2HybridConvCache(DynamicCache):
"""
Attention and conv cache for Lfm2.
It stores the Key and Value states as a list of tensors, one for each layer.
Attention layer cache shape: `[batch_size, num_heads, seq_len, head_dim]`.
Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`.
"""
def __init__(
self,
config: Lfm2Config,
max_batch_size: int,
dtype: torch.dtype = torch.float32,
device: Union[torch.device, str, None] = None,
):
super().__init__() # initialize key and value cache
self.max_batch_size = max_batch_size
self.layer_types = config.layer_types
self.first_attention_layer = self.layer_types.index("full_attention")
self.conv_L_cache = config.conv_L_cache
self._dtype = dtype
self.conv_cache: list[torch.Tensor] = []
device = torch.device(device) if device is not None else None
for _ in range(config.num_hidden_layers):
conv_state = torch.zeros(
self.max_batch_size,
config.hidden_size,
self.conv_L_cache,
dtype=self._dtype,
device=device,
)
torch._dynamo.mark_static_address(conv_state)
self.conv_cache.append(conv_state)
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
Return:
A tuple containing the updated key and value states.
"""
# Update the number of seen tokens
if layer_idx == self.first_attention_layer:
self._seen_tokens += key_states.shape[-2]
# Update the cache
if key_states is not None:
if len(self.key_cache) <= layer_idx:
# There may be skipped layers, fill them with empty lists
for _ in range(len(self.key_cache), layer_idx):
self.key_cache.append(torch.tensor([]))
self.value_cache.append(torch.tensor([]))
self.key_cache.append(key_states)
self.value_cache.append(value_states)
elif (
not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model
): # fills previously skipped layers; checking for tensor causes errors
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.conv_cache[layer_idx].device
self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device))
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# take any layer that contains cache and not empty tensor
layer_idx = self.first_attention_layer if self.layer_types[layer_idx] != "full_attention" else layer_idx
if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0:
return 0
return self.key_cache[layer_idx].shape[-2]
def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]:
raise NotImplementedError("Lfm2HybridConvCache does not have a legacy cache equivalent.")
@classmethod
def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
raise NotImplementedError("Lfm2HybridConvCache does not have a legacy cache equivalent.")
def reset(self):
for layer_idx in range(len(self.conv_cache)):
# In-place ops prevent breaking the static address
self.conv_cache[layer_idx].zero_()
class Lfm2Attention(LlamaAttention):
def __init__(self, config: Lfm2Config, layer_idx: int):
super().__init__(config, layer_idx)
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
self.q_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)
self.k_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)
del self.o_proj
del self.attention_dropout
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Lfm2HybridConvCache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_layernorm(self.q_proj(hidden_states).view(*hidden_shape)).transpose(1, 2)
key_states = self.k_layernorm(self.k_proj(hidden_states).view(*hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
output = self.out_proj(attn_output)
return output, attn_weights
class Lfm2ShortConv(nn.Module):
def __init__(
self,
config: Lfm2Config,
layer_idx: int,
):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.L_cache = config.conv_L_cache
self.bias = config.conv_bias
self.conv = nn.Conv1d(
in_channels=config.hidden_size,
out_channels=config.hidden_size,
kernel_size=self.L_cache,
groups=config.hidden_size,
bias=self.bias,
padding=self.L_cache - 1,
)
self.in_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=self.bias)
self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=self.bias)
def cuda_kernels_forward(
self,
x: torch.Tensor,
past_key_value: Optional[Lfm2HybridConvCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
x = apply_mask_to_padding_states(x, attention_mask)
BCx = self.in_proj(x).transpose(-1, -2)
B, C, x = BCx.chunk(3, dim=-2)
Bx = B * x
conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2))
if past_key_value is not None and cache_position[0] > 0:
conv_out = causal_conv1d_update(
Bx.squeeze(-1),
past_key_value.conv_cache[self.layer_idx],
conv_weights,
self.conv.bias,
None,
)
conv_out = conv_out.unsqueeze(-1)
else:
if past_key_value is not None:
conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))
past_key_value.conv_cache[self.layer_idx].copy_(conv_state)
conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None)
y = C * conv_out
y = self.out_proj(y.transpose(-1, -2).contiguous())
return y
def slow_forward(
self,
x: torch.Tensor,
past_key_value: Optional[Lfm2HybridConvCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
seqlen = x.shape[1]
x = apply_mask_to_padding_states(x, attention_mask)
BCx = self.in_proj(x).transpose(-1, -2)
B, C, x = BCx.chunk(3, dim=-2)
Bx = B * x
if past_key_value is not None and cache_position[0] > 0:
conv_state = past_key_value.conv_cache[self.layer_idx]
cache_position = cache_position.clamp(0, self.L_cache - 1)
conv_state = conv_state.roll(shifts=-1, dims=-1)
conv_state[:, :, cache_position] = Bx.to(device=conv_state.device, dtype=conv_state.dtype)
past_key_value.conv_cache[self.layer_idx].copy_(conv_state)
conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1)
if self.bias:
conv_out += self.conv.bias
conv_out = conv_out.unsqueeze(-1)
else:
if past_key_value is not None:
conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))
past_key_value.conv_cache[self.layer_idx].copy_(conv_state)
conv_out = self.conv(Bx)[..., :seqlen]
y = C * conv_out
y = y.transpose(-1, -2).contiguous()
y = self.out_proj(y)
return y
def forward(
self,
hidden_states: torch.Tensor,
past_key_value: Optional[Lfm2HybridConvCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
if is_fast_path_available and "cuda" in hidden_states.device.type and not torch._dynamo.is_compiling():
return self.cuda_kernels_forward(hidden_states, past_key_value, cache_position, attention_mask)
return self.slow_forward(hidden_states, past_key_value, cache_position, attention_mask)
class Lfm2DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Lfm2Config, layer_idx: int):
super().__init__()
self.is_attention_layer = config.layer_types[layer_idx] == "full_attention"
if self.is_attention_layer:
self.self_attn = Lfm2Attention(config, layer_idx)
else:
self.conv = Lfm2ShortConv(config, layer_idx)
self.feed_forward = Lfm2MLP(config)
self.operator_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
self.ffn_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[tuple[torch.Tensor]] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
if self.is_attention_layer:
hidden_states, _ = self.self_attn(
hidden_states=self.operator_norm(hidden_states),
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
cache_position=cache_position,
**kwargs,
)
else:
hidden_states = self.conv(
hidden_states=self.operator_norm(hidden_states),
past_key_value=past_key_value,
cache_position=cache_position,
attention_mask=attention_mask,
)
hidden_states = hidden_states + residual
hidden_states = hidden_states + self.feed_forward(self.ffn_norm(hidden_states))
return hidden_states
class Lfm2PreTrainedModel(LlamaPreTrainedModel):
_supports_static_cache = False
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv1d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Lfm2RMSNorm):
module.weight.data.fill_(1.0)
class Lfm2Model(LlamaModel):
def __init__(self, config: Lfm2Config):
super().__init__(config)
self.pos_emb = Lfm2RotaryEmbedding(config)
self.embedding_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
del self.norm
del self.rotary_emv
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Lfm2HybridConvCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
batch_size = inputs_embeds.shape[0]
past_key_values = Lfm2HybridConvCache(
config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device
)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = self.pos_emb(hidden_states, position_ids)
# decoder layers
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = self.embedding_norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
)
class Lfm2ForCausalLM(LlamaForCausalLM):
pass
__all__ = ["Lfm2ForCausalLM", "Lfm2Model", "Lfm2PreTrainedModel"]

View File

@@ -2503,6 +2503,7 @@ class GenerationTesterMixin:
"xlnet", "xlnet",
"zamba", "zamba",
"zamba2", "zamba2",
"lfm2",
) )
has_standard_cache = not any( has_standard_cache = not any(
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache

View File

View File

@@ -0,0 +1,93 @@
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch LLaMA model."""
import unittest
from transformers import is_torch_available
from transformers.testing_utils import (
require_read_token,
require_torch,
require_torch_accelerator,
slow,
)
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
if is_torch_available():
from transformers import Lfm2Config, Lfm2ForCausalLM, Lfm2Model
class Lfm2ModelTester(CausalLMModelTester):
if is_torch_available():
config_class = Lfm2Config
base_model_class = Lfm2Model
causal_lm_class = Lfm2ForCausalLM
def __init__(
self,
parent,
layer_types=["full_attention", "conv"],
):
super().__init__(parent)
self.layer_types = layer_types
@require_torch
class Lfm2ModelTest(CausalLMModelTest, unittest.TestCase):
all_model_classes = (Lfm2Model, Lfm2ForCausalLM) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": Lfm2Model,
"text-generation": Lfm2ForCausalLM,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
fx_compatible = False
model_tester_class = Lfm2ModelTester
# used in `test_torch_compile_for_training`
_torch_compile_train_cls = Lfm2ForCausalLM if is_torch_available() else None
@unittest.skip(
"Lfm2 alternates between attention and conv layers, so attention are only returned for attention layers"
)
def test_attention_outputs(self):
pass
@unittest.skip("Lfm2 has a special cache format as it alternates between attention and conv layers")
def test_past_key_values_format(self):
pass
@unittest.skip("Lfm2 has a special cache format which is not compatible with contrastive search")
def test_contrastive_generate(self):
pass
@unittest.skip("Lfm2 has a special cache format which is not compatible with contrastive search")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("Lfm2 has a special cache format which is not compatible with contrastive search")
def test_contrastive_generate_low_memory(self):
pass
@require_torch_accelerator
@require_read_token
@slow
class Lfm2IntegrationTest(unittest.TestCase):
pass

View File

@@ -32,6 +32,7 @@ transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
SPECIAL_CASES_TO_ALLOW = { SPECIAL_CASES_TO_ALLOW = {
"Lfm2Config": ["full_attn_idxs", "tie_word_embeddings"],
# used internally during generation to provide the custom logit processors with their necessary information # used internally during generation to provide the custom logit processors with their necessary information
"DiaConfig": [ "DiaConfig": [
"delay_pattern", "delay_pattern",