diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index a7c79b002b..f7ba50a22d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -697,6 +697,8 @@ title: XLM-V - local: model_doc/xlnet title: XLNet + - local: model_doc/xlstm + title: xLSTM - local: model_doc/yoso title: YOSO - local: model_doc/zamba diff --git a/docs/source/en/model_doc/xlstm.md b/docs/source/en/model_doc/xlstm.md new file mode 100644 index 0000000000..ba47a5a97c --- /dev/null +++ b/docs/source/en/model_doc/xlstm.md @@ -0,0 +1,47 @@ + + + +# xLSTM + +## Overview + +The xLSTM model was proposed in [xLSTM: Extended Long Short-Term Memory](https://openreview.net/forum?id=ARAxPPIAhq) by Maximilian Beck*, Korbinian Pöppel*, Markus Spanring, Andreas Auer, Oleksandra Prudnikova, Michael Kopp, Günter Klambauer, Johannes Brandstetter and Sepp Hochreiter. +xLSTM updates the original LSTM architecture to be competitive with Transformer models by introducing exponential gating, matrix memory expansion, and parallelizable training and ingestion. + +The [7B model](https://hf.co/NX-AI/xLSTM-7b) variant was trained by the xLSTM team Maximilian Beck, Korbinian Pöppel, Phillip Lippe, Richard Kurle, Patrick Blies, Sebastian Böck and Sepp Hochreiter at NXAI. + +The abstract from the paper is the following: + +*In the 1990s, the constant error carousel and gating were introduced as the central ideas of the Long Short-Term Memory (LSTM). Since then, LSTMs have stood the test of time and contributed to numerous deep learning success stories, in particular they constituted the first Large Language Models (LLMs). However, the advent of the Transformer technology with parallelizable self-attention at its core marked the dawn of a new era, outpacing LSTMs at scale. We now raise a simple question: How far do we get in language modeling when scaling LSTMs to billions of parameters, leveraging the latest techniques from modern LLMs, but mitigating known limitations of LSTMs? 
Firstly, we introduce exponential gating with appropriate normalization and stabilization techniques. Secondly, we modify the LSTM memory structure, obtaining: (i) sLSTM with a scalar memory, a scalar update, and new memory mixing, (ii) mLSTM that is fully parallelizable with a matrix memory and a covariance update rule. Integrating these LSTM extensions into residual block backbones yields xLSTM blocks that are then residually stacked into xLSTM architectures. Exponential gating and modified memory structures boost xLSTM capabilities to perform favorably when compared to state-of-the-art Transformers and State Space Models, both in performance and scaling.* + +This model was contributed by [NX-AI](https://huggingface.co/NX-AI). +The original code can be found [here](https://github.com/NX-AI/xlstm). + + +## xLSTMConfig + +[[autodoc]] xLSTMConfig + +## xLSTMModel + +[[autodoc]] xLSTMModel + - forward + +## xLSTMForCausalLM + +[[autodoc]] xLSTMForCausalLM + - forward diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index b691cea112..3670833bf6 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -353,6 +353,7 @@ if TYPE_CHECKING: from .xlm_roberta import * from .xlm_roberta_xl import * from .xlnet import * + from .xlstm import * from .xmod import * from .yolos import * from .yoso import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index eb25e0d025..0317832fd7 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -410,6 +410,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str]( ("xlm-roberta", "XLMRobertaConfig"), ("xlm-roberta-xl", "XLMRobertaXLConfig"), ("xlnet", "XLNetConfig"), + ("xlstm", "xLSTMConfig"), ("xmod", "XmodConfig"), ("yolos", "YolosConfig"), ("yoso", "YosoConfig"), @@ -832,6 +833,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str]( 
("xlnet", "XLNet"), ("xls_r", "XLS-R"), ("xlsr_wav2vec2", "XLSR-Wav2Vec2"), + ("xlstm", "xLSTM"), ("xmod", "X-MOD"), ("yolos", "YOLOS"), ("yoso", "YOSO"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 85eb8ff6bb..cc779b79a1 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -379,6 +379,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("xlm-roberta", "XLMRobertaModel"), ("xlm-roberta-xl", "XLMRobertaXLModel"), ("xlnet", "XLNetModel"), + ("xlstm", "xLSTMModel"), ("xmod", "XmodModel"), ("yolos", "YolosModel"), ("yoso", "YosoModel"), @@ -474,6 +475,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( ("xlm-roberta", "XLMRobertaForMaskedLM"), ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"), ("xlnet", "XLNetLMHeadModel"), + ("xlstm", "xLSTMForCausalLM"), ("xmod", "XmodForMaskedLM"), ] ) @@ -692,6 +694,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("xlm-roberta", "XLMRobertaForCausalLM"), ("xlm-roberta-xl", "XLMRobertaXLForCausalLM"), ("xlnet", "XLNetLMHeadModel"), + ("xlstm", "xLSTMForCausalLM"), ("xmod", "XmodForCausalLM"), ("zamba", "ZambaForCausalLM"), ("zamba2", "Zamba2ForCausalLM"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 6e5b07dddf..3747597f3e 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -718,6 +718,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]]( "XLNetTokenizerFast" if is_tokenizers_available() else None, ), ), + ("xlstm", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ( "xmod", ( diff --git a/src/transformers/models/xlstm/__init__.py b/src/transformers/models/xlstm/__init__.py new file mode 100644 index 0000000000..00e206973a --- /dev/null +++ b/src/transformers/models/xlstm/__init__.py @@ -0,0 +1,31 @@ +# Copyright 
if TYPE_CHECKING:
    # These must be package-relative: `configuration_xlstm` / `modeling_xlstm` are
    # submodules of this package, not top-level modules, so an absolute
    # `from configuration_xlstm import *` cannot be resolved by type checkers.
    from .configuration_xlstm import *
    from .modeling_xlstm import *
else:
    import sys

    # At runtime the package is swapped for a lazy module built from this file's
    # import structure, so submodules are only imported on first attribute access.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
class xLSTMConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`xLSTMModel`]. It is used to instantiate a xLSTM
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the xLSTM-7b [NX-AI/xLSTM-7b](https://huggingface.co/NX-AI/xLSTM-7b) model.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 50304):
            Vocabulary size of the xLSTM model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`xLSTMModel`]. Defaults to the GPT2-NeoX tokenizer size.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimensionality of the embeddings or hidden states. If `None` is passed, `embedding_dim` is used instead.
        embedding_dim (`int`, *optional*):
            Dimensionality of the embeddings or hidden states, uses `hidden_size` if `None`.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of blocks of the xLSTM model. If `None` is passed, `num_blocks` is used instead.
        num_blocks (`int`, *optional*):
            Number of blocks of the xLSTM model, uses `num_hidden_layers` if `None`.
        num_heads (`int`, *optional*, defaults to 8):
            Number of heads for the xLSTM Layer/Cell.
        use_bias (`bool`, *optional*, defaults to `False`):
            Whether to use biases in the xLSTM model.
        norm_reduction_force_float32 (`bool`, *optional*, defaults to `True`):
            Whether to force the norm reduction op to be done in fp32 precision.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie word embeddings to the lm head weights.
        add_out_norm (`bool`, *optional*, defaults to `True`):
            Whether to add an output norm after the blocks before the LMHead.
        norm_eps (`float`, *optional*, defaults to 1e-06):
            Norm eps for RMSNorm and Layer Norm.
        qk_dim_factor (`float`, *optional*, defaults to 0.5):
            Scale factor for the query and key dimension.
        v_dim_factor (`float`, *optional*, defaults to 1.0):
            Scale factor for the value dimension.
        chunkwise_kernel (`ChunkwiseKernelType`, *optional*, defaults to `"chunkwise--native_autograd"`):
            Kernel type for chunkwise processing mode.
        sequence_kernel (`SequenceKernelType`, *optional*, defaults to `"native_sequence__native"`):
            Kernel type for sequence processing mode.
        step_kernel (`StepKernelType`, *optional*, defaults to `"native"`):
            Kernel type for step processing mode.
        mode (`BackendModeType`, *optional*, defaults to `"inference"`):
            Operation mode (inference is needed for generation).
        chunk_size (`int`, *optional*, defaults to 64):
            Internal chunk size.
        return_last_states (`bool`, *optional*, defaults to `True`):
            Whether to return the last states / cache internally. Needs to be `True` for generation.
        autocast_kernel_dtype (`DtypeType`, *optional*, defaults to `"bfloat16"`):
            Kernel dtype for the states.
        eps (`float`, *optional*, defaults to 1e-06):
            Epsilon for the mLSTM cell post norm.
        inference_state_dtype (`DtypeType`, *optional*, defaults to `"float32"`):
            Kernel dtype for states in inference.
        ffn_proj_factor (`float`, *optional*, defaults to 2.667):
            Size factor of the post-up projection gated Feed Forward network.
        ffn_round_up_to_multiple_of (`int`, *optional*, defaults to 64):
            Size factor round value of the post-up projection gated Feed Forward network.
        gate_soft_cap (`float`, *optional*, defaults to 15.0):
            Gate soft cap scale.
        output_logit_soft_cap (`float`, *optional*, defaults to 30.0):
            Output logit soft cap scale.
        weight_mode (`WeightModeType`, *optional*, defaults to `"single"`):
            Whether parallel linear layers are separated or fused. One of `"single"` or `"fused"`.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether to use the cache (xLSTMCache).
        pad_token_id (`int`, *optional*, defaults to 1):
            Pad token id needed for generation.
        bos_token_id (`int`, *optional*, defaults to 0):
            BOS token id needed for generation.
        eos_token_id (`int`, *optional*, defaults to 2):
            EOS token id needed for generation.
        max_inference_chunksize (`int`, *optional*, defaults to 16384):
            Limit the chunk size for inference to save memory.

    Example:

    ```python
    >>> from transformers import xLSTMConfig, xLSTMModel

    >>> # Initializing a xLSTM configuration
    >>> configuration = xLSTMConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = xLSTMModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "xlstm"

    def __init__(
        self,
        vocab_size: int = 50304,
        hidden_size: int = 4096,
        embedding_dim: Optional[int] = None,
        num_hidden_layers: Optional[int] = 32,
        num_blocks: Optional[int] = None,
        num_heads: int = 8,
        use_bias: bool = False,
        norm_reduction_force_float32: bool = True,
        tie_word_embeddings: bool = False,
        add_out_norm: bool = True,
        norm_eps: float = 1e-6,
        # mlstm_layer
        qk_dim_factor: float = 0.5,
        v_dim_factor: float = 1.0,
        # mlstm backend
        chunkwise_kernel: ChunkwiseKernelType = "chunkwise--native_autograd",
        sequence_kernel: SequenceKernelType = "native_sequence__native",
        step_kernel: StepKernelType = "native",
        # needed to enable generation
        mode: BackendModeType = "inference",
        chunk_size: int = 64,
        # needed to be true for generation
        return_last_states: bool = True,
        autocast_kernel_dtype: DtypeType = "bfloat16",
        eps: float = 1e-6,
        inference_state_dtype: DtypeType = "float32",
        # feedforward
        ffn_proj_factor: float = 2.667,
        ffn_round_up_to_multiple_of: int = 64,
        # capping
        gate_soft_cap: float = 15.0,
        output_logit_soft_cap: float = 30.0,
        # weights
        weight_mode: WeightModeType = "single",
        # HF interface
        use_cache: bool = True,
        pad_token_id: int = 1,
        bos_token_id: int = 0,
        eos_token_id: int = 2,
        max_inference_chunksize: int = 16384,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        # `hidden_size`/`embedding_dim` and `num_hidden_layers`/`num_blocks` are paired
        # names: each one falls back to its partner only when it is passed as `None`.
        self.hidden_size = hidden_size if hidden_size is not None else embedding_dim
        self.embedding_dim = embedding_dim if embedding_dim is not None else hidden_size
        self.num_hidden_layers = num_hidden_layers if num_hidden_layers is not None else num_blocks
        self.num_blocks = num_blocks if num_blocks is not None else num_hidden_layers
        self.num_heads = num_heads
        self.use_bias = use_bias
        self.tie_word_embeddings = tie_word_embeddings
        self.add_out_norm = add_out_norm
        self.norm_eps = norm_eps
        self.norm_reduction_force_float32 = norm_reduction_force_float32
        # mlstm_layer
        self.qk_dim_factor = qk_dim_factor
        self.v_dim_factor = v_dim_factor
        # mlstm backend
        self.chunkwise_kernel = chunkwise_kernel
        self.sequence_kernel = sequence_kernel
        self.step_kernel = step_kernel
        self.mode = mode
        self.chunk_size = chunk_size
        self.return_last_states = return_last_states
        self.autocast_kernel_dtype = autocast_kernel_dtype
        self.eps = eps
        self.inference_state_dtype = inference_state_dtype
        # feedforward
        self.ffn_proj_factor = ffn_proj_factor
        self.ffn_round_up_to_multiple_of = ffn_round_up_to_multiple_of
        # capping
        self.gate_soft_cap = gate_soft_cap
        self.output_logit_soft_cap = output_logit_soft_cap
        self.weight_mode = weight_mode

        self.use_cache = use_cache
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.max_inference_chunksize = max_inference_chunksize

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @property
    def qk_dim(self):
        """Query/key dimension: `hidden_size * qk_dim_factor`, rounded up to the next multiple of 64."""
        return round_up_to_next_multiple_of(
            self.hidden_size * self.qk_dim_factor,
            multiple_of=64,
        )

    @property
    def v_dim(self):
        """Value dimension: `hidden_size * v_dim_factor`, rounded up to the next multiple of 64."""
        return round_up_to_next_multiple_of(
            self.hidden_size * self.v_dim_factor,
            multiple_of=64,
        )

    @property
    def qk_head_dim(self):
        """Per-head query/key dimension: `qk_dim` floor-divided by `num_heads`."""
        return self.qk_dim // self.num_heads

    @property
    def v_head_dim(self):
        """Per-head value dimension: `v_dim` floor-divided by `num_heads`."""
        return self.v_dim // self.num_heads

    def to_xlstm_block_config(self):
        """
        Return the configuration object used to build the xLSTM blocks.

        When the external `xlstm` package is available, the fields of this config are copied into an
        `xLSTMLargeConfig` (with `embedding_dim` taken from `hidden_size` and `num_blocks` taken from
        `num_hidden_layers`). Otherwise this config object itself is returned.
        """
        if external_xlstm:
            return xLSTMLargeConfig(
                vocab_size=self.vocab_size,
                embedding_dim=self.hidden_size,
                num_blocks=self.num_hidden_layers,
                num_heads=self.num_heads,
                use_bias=self.use_bias,
                add_out_norm=self.add_out_norm,
                norm_eps=self.norm_eps,
                norm_reduction_force_float32=self.norm_reduction_force_float32,
                # mlstm_layer
                qk_dim_factor=self.qk_dim_factor,
                v_dim_factor=self.v_dim_factor,
                # mlstm backend
                chunkwise_kernel=self.chunkwise_kernel,
                sequence_kernel=self.sequence_kernel,
                step_kernel=self.step_kernel,
                mode=self.mode,
                chunk_size=self.chunk_size,
                return_last_states=self.return_last_states,
                autocast_kernel_dtype=self.autocast_kernel_dtype,
                eps=self.eps,
                inference_state_dtype=self.inference_state_dtype,
                # feedforward
                ffn_proj_factor=self.ffn_proj_factor,
                ffn_round_up_to_multiple_of=self.ffn_round_up_to_multiple_of,
                # capping
                gate_soft_cap=self.gate_soft_cap,
                output_logit_soft_cap=self.output_logit_soft_cap,
                weight_mode=self.weight_mode,
            )
        else:
            return self
def soft_cap(values: torch.Tensor, cap_value: Optional[Union[float, torch.Tensor]] = None) -> torch.Tensor:
    """
    Smoothly bound a tensor to the open interval (-cap_value, cap_value).

    The input is divided by the cap, squashed with tanh and multiplied by the cap again. This is commonly
    applied to attention scores and language-model logits so that single large values cannot dominate the
    softmax. See for example Gemma2: https://arxiv.org/abs/2408.00118

    Args:
        values: The tensor to cap.
        cap_value: The bound to apply. If None, `values` is returned untouched.

    Returns:
        The capped values.
    """
    capping_enabled = cap_value is not None
    if capping_enabled:
        squashed = torch.tanh(values / cap_value)
        return cap_value * squashed
    return values
def mlstm_chunkwise_parallel_fw_H(
    matQ: torch.Tensor,
    matK: torch.Tensor,
    matV: torch.Tensor,
    # these states must be all states up to the last chunk, i.e. :-1
    matC_states: torch.Tensor,
    vecN_states: torch.Tensor,
    scaMinter_states: torch.Tensor,
    vecI: torch.Tensor,
    vecB: torch.Tensor,
    qk_scale: float,
    chunk_size: int = 64,
    num_chunks: int = 1,
    eps: float = 1e-6,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the mLSTM outputs of all chunks in parallel from the per-chunk start states.

    Args:
        matQ, matK: Queries / keys; viewed as (batch, heads, num_chunks, chunk_size, dhqk).
        matV: Values; viewed as (batch, heads, num_chunks, chunk_size, dhv).
        matC_states: Stacked per-chunk memory states, shape (batch, heads, num_chunks * dhqk, dhv).
        vecN_states: Stacked per-chunk normalizer states; viewed as (batch, heads, num_chunks, dhqk).
        scaMinter_states: Per-chunk max states, indexed as (batch, heads, num_chunks).
        vecI: Input gate pre-activations, indexed as (batch, heads, num_chunks, chunk_size).
        vecB: Per-chunk cumulative log forget gates, same indexing as `vecI`.
        qk_scale: Scale applied to the query-key products.
        chunk_size: Number of time steps per chunk.
        num_chunks: Number of chunks the sequence is split into.
        eps: Added to the denominator for numerical stability.

    Returns:
        A tuple `(matH_out, vecN_out, vecM_out)` with shapes (batch, heads, num_chunks * chunk_size, dhv),
        (batch, heads, num_chunks * chunk_size) and (batch, heads, num_chunks * chunk_size).
    """
    _device = matQ.device
    nc = num_chunks
    # dim 2 of the stacked states holds `nc` states of size dhqk each
    batch_size, nh, dqk, dhv = matC_states.shape
    dhqk = dqk // nc
    matC_k_states = matC_states.view(batch_size, nh, nc, dhqk, dhv)
    vecN_k_states = vecN_states.view(batch_size, nh, nc, dhqk)
    scaMinter_k_states = scaMinter_states

    # split the sequence axis into (chunk, position-in-chunk); the feature axis is the
    # per-state size dhqk, NOT the stacked size dqk (they only coincide for nc == 1)
    matQ = matQ.view(batch_size, nh, nc, chunk_size, dhqk)
    matK = matK.view(batch_size, nh, nc, chunk_size, dhqk)
    matV = matV.view(batch_size, nh, nc, chunk_size, dhv)

    # lower-triangular (causal) mask within a chunk
    ltr = torch.tril(
        torch.ones(
            (chunk_size, chunk_size),
            dtype=torch.bool,
            device=_device,
        )
    )

    # Compute intra chunk contribution: H_intra
    matF_logsig_chunk = vecB[:, :, :, :, None] - vecB[:, :, :, None, :]

    matF_logsig_mask_chunk = torch.where(ltr, matF_logsig_chunk, -float("inf"))

    matLogD_chunk = matF_logsig_mask_chunk + vecI[:, :, :, None, :]

    # max_state intra
    vecMintra_k = torch.max(matLogD_chunk, dim=-1, keepdim=False).values

    # max_state combined
    vecM_b_inter = vecB + scaMinter_k_states[:, :, :, None]
    vecM_k_combine = torch.maximum(vecM_b_inter, vecMintra_k)

    vecM_k_combine = vecM_k_combine[:, :, :, :, None]
    vecM_b_inter = vecM_b_inter[:, :, :, :, None]

    # subtract the combined max before exponentiating to keep exp() bounded
    matLogD_stabilized_chunk = matLogD_chunk - vecM_k_combine
    matD_chunk = torch.exp(matLogD_stabilized_chunk)

    matS_chunk = (matQ @ matK.transpose(-2, -1)) * qk_scale

    matM_chunk = matS_chunk * matD_chunk

    # Combine H_intra with H_inter
    vecBbar = torch.exp(vecM_b_inter - vecM_k_combine)
    matQ_chunk_gated = matQ * vecBbar * qk_scale

    matNumerator_common = matQ_chunk_gated @ matC_k_states + matM_chunk @ matV

    vecDenom_l_common = matQ_chunk_gated @ vecN_k_states.unsqueeze(-1) + matM_chunk.sum(dim=-1, keepdim=True)

    vecDenom_max_common = torch.maximum(torch.abs(vecDenom_l_common), torch.exp(-vecM_k_combine))

    matH_k_chunk = matNumerator_common / (vecDenom_max_common + eps)

    matH_out = matH_k_chunk.view(batch_size, nh, nc * chunk_size, dhv)

    # we need the denominator and the overall max state for the backward pass
    vecN_out = vecDenom_max_common.reshape(batch_size, nh, nc * chunk_size)
    vecM_out = vecM_k_combine.reshape(batch_size, nh, nc * chunk_size)
    return matH_out, vecN_out, vecM_out


def mlstm_chunkwise_fw(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    igate: torch.Tensor,
    fgate: torch.Tensor,
    cstate: Optional[torch.Tensor] = None,
    nstate: Optional[torch.Tensor] = None,
    mstate: Optional[torch.Tensor] = None,
    qk_scale: Optional[float] = None,
    return_last_states: bool = False,
    return_all_states: bool = False,
    chunk_size: int = 64,
    eps: float = 1e-6,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
    Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
]:
    """
    Chunkwise mLSTM forward pass that also exposes the normalizer, max state and chunk states.

    Args:
        query: Queries of shape (batch, heads, sequence_length, dhqk).
        key, value: Keys and values, forwarded to the chunk kernels.
        igate, fgate: Gate pre-activations; viewed as (batch, heads, num_chunks, chunk_size).
        cstate, nstate, mstate: Optional initial memory, normalizer and max states.
        qk_scale: Scale for the query-key products; defaults to `dhqk ** -0.5`.
        return_last_states: Whether the 4th tuple entry holds the states after the last chunk (else `None`).
        return_all_states: Whether the 5th tuple entry holds the states of all chunks (else `None`).
        chunk_size: Number of time steps per chunk.
        eps: Added to the denominator for numerical stability.

    Returns:
        `(matH_out, vecN_out, vecM_out, last_states_or_None, all_states_or_None)`.

    Raises:
        ValueError: If `sequence_length` is not divisible by `chunk_size`.
    """
    batch_size, nh, sequence_length, dhqk = query.shape
    if sequence_length % chunk_size != 0:
        raise ValueError(f"Sequence length {sequence_length} is not divisible by chunk size {chunk_size}.")
    nc = sequence_length // chunk_size

    vecI = igate.view(batch_size, nh, nc, chunk_size)
    vecF = fgate.view(batch_size, nh, nc, chunk_size)

    # compute the gates, the g and the a and b vectors
    # (log-sigmoid is a functional op, not a tensor method)
    vecF_logsig = F.logsigmoid(vecF)
    vecB = vecF_logsig.cumsum(-1)

    if qk_scale is None:
        qk_scale = dhqk**-0.5

    # materialize the C_k, n_k, m_k states for each chunk
    matC_k_states, vecN_k_states, scaMinter_k_states = mlstm_chunkwise_recurrent_fw_C(
        matK=key,
        matV=value,
        vecB=vecB,
        vecI=vecI,
        matC_initial=cstate,
        vecN_initial=nstate,
        scaMinter_initial=mstate,
        qk_scale=qk_scale,
        chunk_size=chunk_size,
        num_chunks=nc,
    )

    # compute the outputs within each chunk; the trailing state (after the last chunk) is excluded
    matH_out, vecN_out, vecM_out = mlstm_chunkwise_parallel_fw_H(
        matQ=query,
        matK=key,
        matV=value,
        matC_states=matC_k_states[:, :, :-dhqk, :],
        vecN_states=vecN_k_states[:, :, :-dhqk],
        scaMinter_states=scaMinter_k_states[:, :, :-1],
        vecI=vecI,
        vecB=vecB,
        qk_scale=qk_scale,
        chunk_size=chunk_size,
        num_chunks=nc,
        eps=eps,
    )

    ret_tuple = (matH_out, vecN_out, vecM_out)
    if return_last_states:
        ret_tuple += (
            (matC_k_states[:, :, -dhqk:, :], vecN_k_states[:, :, -dhqk:], scaMinter_k_states[:, :, -1:]),
        )
    else:
        ret_tuple += (None,)

    if return_all_states:
        ret_tuple += ((matC_k_states, vecN_k_states, scaMinter_k_states),)
    else:
        ret_tuple += (None,)

    return ret_tuple
def mlstm_recurrent_step_native(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    igate: torch.Tensor,
    fgate: torch.Tensor,
    cstate: torch.Tensor,
    nstate: torch.Tensor,
    mstate: torch.Tensor,
    eps: float = 1e-6,
    dtype_state: torch.dtype = torch.float32,
    **kwargs,
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Advance the mLSTM recurrence by exactly one time step.

    Args:
        query, key: Tensors of shape (batch, heads, dhqk); both must have the same shape.
        value: Tensor of shape (batch, heads, dhhv).
        igate, fgate: Gate pre-activations of shape (batch, heads, 1).
        cstate: Memory state of shape (batch, heads, dhqk, dhhv).
        nstate: Normalizer state of shape (batch, heads, dhqk).
        mstate: Max state of shape (batch, heads, 1).
        eps: Added to the denominator for numerical stability.
        dtype_state: Dtype the states are cast to on entry and on return.
        **kwargs: Accepted and ignored.

    Returns:
        `(h, (c_new, n_new, m_new))`: `h` has the dtype of `query`, the states have `dtype_state`.

    Raises:
        ValueError: If any input does not have the shape listed above.
    """
    io_dtype = query.dtype
    matC_old, vecN_old, scaM_old = (state.to(dtype=dtype_state) for state in (cstate, nstate, mstate))

    batch_size, nh, dhqk = query.shape
    _, _, dhhv = value.shape
    gate_shape = (batch_size, nh, 1)
    if query.shape != key.shape:
        raise ValueError("query and key must have the same shape")
    if matC_old.shape != (batch_size, nh, dhqk, dhhv):
        raise ValueError(f"matC_old has wrong shape, got {matC_old.shape}")
    if vecN_old.shape != (batch_size, nh, dhqk):
        raise ValueError(f"vecN_old has wrong shape, got {vecN_old.shape}")
    if scaM_old.shape != gate_shape:
        raise ValueError(f"scaM_old has wrong shape, got {scaM_old.shape}")
    if igate.shape != gate_shape:
        raise ValueError(f"scaI has wrong shape, got {igate.shape}")
    if fgate.shape != gate_shape:
        raise ValueError(f"scaF has wrong shape, got {fgate.shape}")

    # log forget gate plus the old max state; the new max state keeps both exponentials below <= 1
    decayed_max = torch.nn.functional.logsigmoid(fgate) + scaM_old
    m_new = torch.max(decayed_max, igate)
    forget_act = torch.exp(decayed_max - m_new)
    input_act = torch.exp(igate - m_new)

    # state updates: decay the old state and add the gated key/value contribution
    kv_outer = key.unsqueeze(3) @ value.unsqueeze(2)
    c_new = forget_act.unsqueeze(3) * matC_old + input_act.unsqueeze(3) * kv_outer
    n_new = forget_act * vecN_old + input_act * key

    # read-out: scaled query against the new states, matmuls done in the query dtype
    q_row = (query * (dhqk ** (-0.5))).unsqueeze(2)
    numerator = (q_row @ c_new.to(dtype=io_dtype)).squeeze(2).to(dtype=dtype_state)
    qn = (q_row @ n_new.unsqueeze(3).to(dtype=io_dtype)).squeeze(2)
    lower_bound = torch.exp(-m_new)
    denominator = (torch.maximum(qn.abs(), lower_bound) + eps).to(dtype=dtype_state)

    h = (numerator / denominator).to(dtype=io_dtype)
    new_states = (
        c_new.to(dtype=dtype_state),
        n_new.to(dtype=dtype_state),
        m_new.to(dtype=dtype_state),
    )
    return h, new_states
def wrap_chunkwise_pad_zeros(
    mlstm_chunkwise_kernel: Callable,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    fgate: torch.Tensor,
    igate: torch.Tensor,
    c_initial: Optional[torch.Tensor] = None,
    n_initial: Optional[torch.Tensor] = None,
    m_initial: Optional[torch.Tensor] = None,
    return_last_states: bool = False,
    eps: float = 1e-6,
    autocast_kernel_dtype: torch.dtype = torch.bfloat16,
    chunk_size: int = 64,
    **kwargs,
) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
    """Run a chunkwise mLSTM kernel on inputs zero-padded to a multiple of `chunk_size`.

    The sequence axis (dim 2) of `query`, `key`, `value`, `igate` and `fgate` is padded
    with zeros up to the next multiple of `chunk_size`, the kernel is called with keyword
    arguments only, and the kernel output is cut back to the original sequence length.
    When the sequence length already is a multiple of `chunk_size` the inputs are
    forwarded untouched (no copy).

    Args:
        mlstm_chunkwise_kernel: Callable invoked with the (padded) tensors and all other
            arguments as keywords. Its return value is sliced as a 4-D tensor.
        query, key, value: 4-D tensors indexed as (batch, head, sequence, feature).
        fgate, igate: 3-D gate tensors indexed as (batch, head, sequence).
        c_initial, n_initial, m_initial: Optional initial states, forwarded unchanged.
        return_last_states: Must be falsy; see Raises.
        eps, autocast_kernel_dtype, chunk_size, **kwargs: Forwarded to the kernel.

    Returns:
        The kernel output restricted to the first `sequence_length` positions of dim 2.

    Raises:
        ValueError: If `return_last_states` is truthy. The states after the padded
            positions are not the states after the real sequence, so they are never returned.
    """
    if return_last_states:
        # A single message string: passing two comma-separated strings to ValueError
        # would turn the message into a 2-tuple.
        raise ValueError(
            "We are padding zeros, so we cannot return last states, "
            "as they would be not the true last states."
        )

    batch_size, nh, sequence_length, _ = query.shape
    S_unpadded = sequence_length

    if sequence_length % chunk_size != 0:
        # Round the sequence length up to the next multiple of the chunk size.
        S_padded = ((sequence_length + chunk_size - 1) // chunk_size) * chunk_size

        q_pad = query.new_zeros(batch_size, nh, S_padded, query.shape[3])
        k_pad = key.new_zeros(batch_size, nh, S_padded, key.shape[3])
        v_pad = value.new_zeros(batch_size, nh, S_padded, value.shape[3])
        i_pad = igate.new_zeros(batch_size, nh, S_padded)
        f_pad = fgate.new_zeros(batch_size, nh, S_padded)

        q_pad[:, :, :S_unpadded, :] = query
        k_pad[:, :, :S_unpadded, :] = key
        v_pad[:, :, :S_unpadded, :] = value
        i_pad[:, :, :S_unpadded] = igate
        f_pad[:, :, :S_unpadded] = fgate
    else:
        q_pad, k_pad, v_pad, i_pad, f_pad = query, key, value, igate, fgate

    matH = mlstm_chunkwise_kernel(
        query=q_pad,
        key=k_pad,
        value=v_pad,
        igate=i_pad,
        fgate=f_pad,
        c_initial=c_initial,
        n_initial=n_initial,
        m_initial=m_initial,
        return_last_states=return_last_states,
        eps=eps,
        autocast_kernel_dtype=autocast_kernel_dtype,
        chunk_size=chunk_size,
        **kwargs,
    )
    # Drop the outputs that belong to padded positions.
    return matH[:, :, :S_unpadded, :]
function computes the last hidden state and matH outputs of the mLSTM, independently of the sequence length. + + For this it uses three kernels: + - mlstm_chunkwise_kernel: mlstm chunkwise kernels that processes chunks of a given chunk size in parallel. + - mlstm_sequence_kernel: mlstm kernel that processes the remaining sequence length in a single step recurrence. + - mlstm_step_kernel: mlstm kernel that processes a sequence length of 1 in a single step. + + It tries to maximize the chunksizes to improve performance. + It will start with the given chunk size and then divides the chunksize by 2 until the chunk size is smaller than 16. + At every chunksize it will process the maximal number of chunks that fit into the remaining sequence length. + + E.g. for chunk_size = 64, this function will try the chunksizes [64, 32, 16] if necessary. + + For the remaining sequence length, which is smaller than 16, we use a different kernel that computes the mLSTM + in a single step and loop over this in pytorch. 
+ + Args: + mlstm_chunkwise_kernel: The mLSTM chunkwise kernel that processes chunks of a given chunk size in parallel + mlstm_sequence_kernel: The mLSTM kernel that processes the remaining sequence length in a single step recurrence + query: The query tensor (batch_size, nh, sequence_length, dhqk) + key: The key tensor (batch_size, nh, sequence_length, dhqk) + value: The value tensor (batch_size, nh, sequence_length, dhhv) + fgate: The forget gate tensor (batch_size, nh, sequence_length) + igate: The input gate tensor (batch_size, nh, sequence_length) + c_initial: The initial cell state tensor (batch_size, nh, dhqk, dhhv) + n_initial: The initial hidden state tensor (batch_size, nh, dhqk) + m_initial: The initial memory state tensor (batch_size, nh, 1) + return_last_states: If True, the function will return the last states of the mLSTM + eps: The epsilon value used for numerical stability + autocast_kernel_dtype: The dtype used for the kernel computation + chunk_size: The chunk size used for the chunkwise kernel + enable_logging: If True, the function will log debug information. Default is False. + + Returns: + The last hidden state tensor (batch_size, nh, sequence_length, dhhv) or a tuple containing the last hidden state tensor and the last states of the mLSTM + Last states are (cstate (batch_size, nh, dhqk, dhhv), nstate (batch_size, nh, dhqk), mstate (batch_size, nh, 1)). 
+ """ + + batch_size, nh, sequence_length, dhqk = key.shape + dhhv = value.shape[-1] + + c_state = ( + c_initial + if c_initial is not None + else torch.zeros(batch_size, nh, dhqk, dhhv, device=key.device, dtype=torch.float32) + ) + n_state = ( + n_initial + if n_initial is not None + else torch.zeros(batch_size, nh, dhqk, device=key.device, dtype=torch.float32) + ) + m_state = ( + m_initial + if m_initial is not None + else torch.zeros(batch_size, nh, 1, device=key.device, dtype=torch.float32) + ) + + if sequence_length > 1: + # process the sequence length in chunks + h_outs = [] + seq_len_start_idx = 0 + remaining_seq_len = sequence_length - seq_len_start_idx + num_chunks = remaining_seq_len // chunk_size + if num_chunks > 0: + iter_seq_len = chunk_size * num_chunks + seq_len_idx = seq_len_start_idx + iter_seq_len + h_out, (c_state, n_state, m_state) = mlstm_chunkwise_kernel( + query=query[..., seq_len_start_idx:seq_len_idx, :].contiguous(), + key=key[..., seq_len_start_idx:seq_len_idx, :].contiguous(), + value=value[..., seq_len_start_idx:seq_len_idx, :].contiguous(), + fgate=fgate[..., seq_len_start_idx:seq_len_idx].contiguous(), + igate=igate[..., seq_len_start_idx:seq_len_idx].contiguous(), + c_initial=c_state, + n_initial=n_state, + m_initial=m_state, + chunk_size=chunk_size, + return_last_states=True, + autocast_kernel_dtype=autocast_kernel_dtype, + eps=eps, + ) + seq_len_start_idx += iter_seq_len + h_outs.append(h_out) + + remaining_seq_len = sequence_length - seq_len_start_idx + + if remaining_seq_len > 0: + # we use here matK as query as this kernel does not need a query, since we do not care about the outputs only about the last state + h_out, (c_state, n_state, m_state) = mlstm_sequence_kernel( + query=query[..., seq_len_start_idx:sequence_length, :].contiguous(), + key=key[..., seq_len_start_idx:sequence_length, :].contiguous(), + value=value[..., seq_len_start_idx:sequence_length, :].contiguous(), + igate=igate[..., 
class xLSTMBackend(nn.Module):
    """xLSTM Backend Module for PyTorch.

    This module wraps the xLSTM kernels and provides a high-level interface for training and inference.

    Two callables are prepared once at construction time from the config:

    - `_inference_fn`: `wrap_chunkwise_arbitrary_sequence_length` bound to the chunkwise,
      sequence and step kernels, with `return_last_states=True`.
    - `_train_fn`: the chunkwise kernel, additionally wrapped in `wrap_chunkwise_pad_zeros`
      when `config.mode` contains "with_padding".
    """

    config_class = xLSTMConfig

    def __init__(self, config: xLSTMConfig):
        super().__init__()
        self.config = config
        self.chunkwise_kernel_fn = mlstm_chunkwise_native_autograd
        self.sequence_kernel_fn = mlstm_recurrent_sequence_native
        self.step_kernel_fn = mlstm_recurrent_step_native

        # dtype names are stored as strings in the config and resolved on the torch module.
        inference_state_dtype = getattr(torch, config.inference_state_dtype)
        autocast_kernel_dtype = getattr(torch, config.autocast_kernel_dtype)

        self._inference_fn = partial(
            wrap_chunkwise_arbitrary_sequence_length,
            mlstm_chunkwise_kernel=self.chunkwise_kernel_fn,
            mlstm_sequence_kernel=partial(
                self.sequence_kernel_fn,
                dtype_state=inference_state_dtype,
            ),
            mlstm_step_kernel=partial(
                self.step_kernel_fn,
                dtype_state=inference_state_dtype,
            ),
            chunk_size=config.chunk_size,
            eps=config.eps,
            autocast_kernel_dtype=autocast_kernel_dtype,
            return_last_states=True,
        )

        train_kernel_fn = partial(
            self.chunkwise_kernel_fn,
            autocast_kernel_dtype=autocast_kernel_dtype,
            eps=config.eps,
            chunk_size=config.chunk_size,
        )
        if "with_padding" in config.mode:
            train_kernel_fn = partial(wrap_chunkwise_pad_zeros, mlstm_chunkwise_kernel=train_kernel_fn)
        self._train_fn = train_kernel_fn

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        igate: torch.Tensor,
        fgate: torch.Tensor,
        c_initial: Optional[torch.Tensor] = None,
        n_initial: Optional[torch.Tensor] = None,
        m_initial: Optional[torch.Tensor] = None,
        return_last_states: bool = False,
        mode: Optional[Literal["train", "inference"]] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
        """Forward pass of the mLSTM backend.

        Depending on the selected mode, this method will call the appropriate kernel function.

        Args:
            query: The query tensor of shape (batch_size, nh, sequence_length, dhqk).
            key: The key tensor of shape (batch_size, nh, sequence_length, dhqk).
            value: The value tensor of shape (batch_size, nh, sequence_length, dhhv).
            igate: The input gate preactivation tensor of shape (batch_size, nh, sequence_length).
            fgate: The forget gate preactivation tensor of shape (batch_size, nh, sequence_length).
            c_initial: The initial cell state tensor of shape (batch_size, nh, dhqk, dhhv).
                Defaults to None.
            n_initial: The initial hidden state tensor of shape (batch_size, nh, dhqk). Defaults to None.
            m_initial: The initial memory tensor of shape (batch_size, nh, 1). Defaults to None.
            return_last_states: Whether to return the last states of the sequence. Defaults to False.
                Only used in train mode; if None is passed explicitly, the value from the config is used.
                Inference mode ignores this argument and always returns the last states.
            mode: Mode to run in. Any string containing "train" selects the training kernel, any
                other string containing "inference" selects the inference path. Defaults to None,
                in which case `config.mode` is used.

        Returns:
            hidden states of shape (batch_size, nh, sequence_length, dhhv)
            hidden states and last states the last states are the cell state cstate (batch_size, nh, dhqk, dhhv),
            the normalizer state nstate (batch_size, nh, dhqk), and the max state mstate (batch_size, nh, 1)

        Raises:
            ValueError: If `mode` contains neither "train" nor "inference", or if last states are
                requested while the module was configured with `train_with_padding`.
        """
        if mode is None:
            mode = self.config.mode

        if "train" in mode:
            if return_last_states is None:
                return_last_states = self.config.return_last_states

            # `_train_fn` was bound at construction time from `config.mode`, so the padding
            # restriction is checked against the config rather than the per-call `mode`.
            if self.config.mode == "train_with_padding" and return_last_states:
                raise ValueError("return_last_states=True is not supported with train_with_padding mode.")

            return self._train_fn(
                query=query,
                key=key,
                value=value,
                igate=igate,
                fgate=fgate,
                c_initial=c_initial,
                n_initial=n_initial,
                m_initial=m_initial,
                return_last_states=return_last_states,
            )

        if "inference" in mode:
            # inference mode always returns the last states
            return self._inference_fn(
                query=query,
                key=key,
                value=value,
                igate=igate,
                fgate=fgate,
                c_initial=c_initial,
                n_initial=n_initial,
                m_initial=m_initial,
            )

        # Report the mode that was actually rejected, which may differ from `config.mode`
        # when the caller passed `mode` explicitly.
        raise ValueError(f"Unknown mode: {mode}")

    def extra_repr(self) -> str:
        return f"{self.config}"
class xLSTMMultiHeadLayerNorm(nn.Module):
    """Layer normalization applied independently to every head.

    The input is expected as (batch_size, sequence_length, nh, DH). Each head vector of
    length DH is centered and scaled by its own (biased) variance; afterwards the head
    axes are flattened and the optional learnable weight and bias, both of size
    nh * DH, are applied.

    Args:
        num_heads: The number of heads.
        head_dim: The head dimension.
        eps: A small value to avoid division by zero.
        use_weight: Whether to use a learnable weight.
        use_bias: Whether to use a learnable bias.
        force_float32_reductions: Whether to force float32 reductions

    Returns:
        The normalized tensor with the shape (batch_size, sequence_length, nh * DH).
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        eps: float = 1e-6,
        use_weight: bool = True,
        use_bias: bool = False,
        force_float32_reductions: bool = True,
    ):
        super().__init__()
        self.num_features = num_heads * head_dim
        self.eps = eps
        self.force_float32_reductions = force_float32_reductions

        # Weight is registered before bias to keep the parameter order stable.
        self.weight = nn.Parameter(torch.ones(self.num_features)) if use_weight else None
        self.bias = nn.Parameter(torch.zeros(self.num_features)) if use_bias else None

        self.num_heads = num_heads
        self.head_dim = head_dim

    def _apply_weight_bias(self, x: torch.Tensor) -> torch.Tensor:
        """Scale by `weight` and shift by `bias`, skipping whichever is disabled."""
        scaled = x if self.weight is None else x * self.weight
        return scaled if self.bias is None else scaled + self.bias

    def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize over the last dimension and return the result in the input dtype."""
        original_dtype = x.dtype
        values = x.float() if self.force_float32_reductions else x
        mean = values.mean(dim=-1, keepdim=True)
        inv_std = torch.rsqrt(values.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
        return ((values - mean) * inv_std).to(original_dtype)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        batch_size, sequence_length, nh, DH = x.shape
        if nh != self.num_heads:
            raise ValueError(f"Expected {self.num_heads} heads, got {nh}, input shape: {x.shape}")
        if DH != self.head_dim:
            raise ValueError(f"Expected {self.head_dim} head dimension, got {DH}, input shape: {x.shape}")

        normalized = self._layer_normalize(x)
        # Merge the head and head-dim axes before the per-feature affine transform.
        flattened = normalized.reshape(batch_size, sequence_length, -1)
        return self._apply_weight_bias(flattened)
class xLSTMLayer(nn.Module): + def __init__(self, config: xLSTMConfig): + super().__init__() + self.config = config + + self.v_dim = int(config.hidden_size * config.v_dim_factor) + self.qk_dim = int(config.hidden_size * config.qk_dim_factor) + + if self.config.weight_mode == "single": + self.query = nn.Linear( + in_features=self.config.hidden_size, + out_features=self.qk_dim, + bias=self.config.use_bias, + ) + self.key = nn.Linear( + in_features=self.config.hidden_size, + out_features=self.qk_dim, + bias=self.config.use_bias, + ) + self.value = nn.Linear( + in_features=self.config.hidden_size, + out_features=self.v_dim, + bias=self.config.use_bias, + ) + + self.ogate_preact = nn.Linear( + in_features=self.config.hidden_size, + out_features=self.v_dim, + bias=self.config.use_bias, + ) + self.igate_preact = nn.Linear( + in_features=self.config.hidden_size, + out_features=self.config.num_heads, + bias=True, + ) + self.fgate_preact = nn.Linear( + in_features=self.config.hidden_size, + out_features=self.config.num_heads, + bias=True, + ) + elif self.config.weight_mode == "fused": + self.qkv_opreact = nn.Linear( + in_features=self.config.hidden_size, + out_features=2 * self.qk_dim + 2 * self.v_dim, + bias=self.config.use_bias, + ) + self.ifgate_preact = nn.Linear( + in_features=self.config.hidden_size, + out_features=2 * self.config.num_heads, + bias=True, + ) + + self.ogate_act_fn = nn.Sigmoid() + self.mlstm_backend = xLSTMBackend(config=self.config) + + self.multihead_norm = xLSTMMultiHeadLayerNorm( + num_heads=self.config.num_heads, + head_dim=self.v_dim // self.config.num_heads, + eps=self.config.norm_eps, + use_weight=True, + use_bias=self.config.use_bias, + force_float32_reductions=self.config.norm_reduction_force_float32, + ) + self.out_proj = nn.Linear( + in_features=self.v_dim, + out_features=self.config.hidden_size, + bias=self.config.use_bias, + ) + + def forward( + self, x: torch.Tensor, state: Optional[mLSTMLayerStateType] = None + ) -> tuple[torch.Tensor, 
Optional[mLSTMLayerStateType]]: + if x.ndim != 3: + raise ValueError(f"Input must have shape [batch_size, sequence_length, HD], got {x.shape}") + batch_size, sequence_length, _ = x.shape + if self.config.weight_mode == "single": + query = self.query(x) + key = self.key(x) + value = self.value(x) + o_preact = self.ogate_preact(x) + i_preact = soft_cap(self.igate_preact(x), cap_value=self.config.gate_soft_cap) + f_preact = soft_cap(self.fgate_preact(x), cap_value=self.config.gate_soft_cap) + + elif self.config.weight_mode == "fused": + qkv_opreact = self.qkv_opreact(x) + query, key, value, o_preact = torch.tensor_split( + qkv_opreact, + ( + self.qk_dim, + 2 * self.qk_dim, + 2 * self.qk_dim + self.v_dim, + ), + dim=-1, + ) + + if_preact = soft_cap(self.ifgate_preact(x), cap_value=self.config.gate_soft_cap) + i_preact, f_preact = torch.tensor_split(if_preact, (self.config.num_heads,), dim=-1) + + query = query.reshape(batch_size, sequence_length, self.config.num_heads, -1).transpose(1, 2) + key = key.reshape(batch_size, sequence_length, self.config.num_heads, -1).transpose(1, 2) + value = value.reshape(batch_size, sequence_length, self.config.num_heads, -1).transpose(1, 2) + i_preact = i_preact.transpose(1, 2) + f_preact = f_preact.transpose(1, 2) + if state is None: + c_initial, n_initial, m_initial = None, None, None + else: + c_initial, n_initial, m_initial = state + + h, state = self.mlstm_backend( + query=query, + key=key, + value=value, + igate=i_preact, + fgate=f_preact, + c_initial=c_initial, + n_initial=n_initial, + m_initial=m_initial, + ) + expected_h_shape = ( + batch_size, + self.config.num_heads, + sequence_length, + self.v_dim // self.config.num_heads, + ) + if h.shape != expected_h_shape: + raise ValueError(f"Got {h.shape}, expected {expected_h_shape}") + + h = h.transpose(1, 2) + h_norm = self.multihead_norm(h) + h_norm = h_norm.reshape(batch_size, sequence_length, -1) + + h_out = self.ogate_act_fn(o_preact) * h_norm + + y = self.out_proj(h_out) + 
class xLSTMPreTrainedModel(PreTrainedModel):
    """
    An abstract class for an interface to loading a pre-trained xLSTM model.

    Besides the class-level loading attributes it provides `_init_weights`, which picks an
    initialization scheme per submodule based on the module type and on the name under
    which the module is registered in this model.
    """

    config_class = xLSTMConfig
    base_model_prefix = "backbone"
    _no_split_modules = ["xLSTMBlock"]
    supports_gradient_checkpointing = True
    _is_stateful = True

    def _module_name_map(self, module):
        """Return the dotted name of `module` inside this model, or "" if it is not a submodule."""
        for name, mod in self.named_modules():
            if mod is module:
                return name
        return ""

    def _init_gate_bias(self, bias, module_name):
        """Initialize the bias of an input/forget gate projection in place.

        Input gate biases are set to -10 and forget gate biases to evenly spaced values
        from 3 to 6. In "single" weight mode the two gates are separate modules
        ("igate" / "fgate" in the name); in "fused" mode the "ifgate" module holds the
        input gate in the first `num_heads` entries and the forget gate in the rest.
        Biases of other gate-named modules are left untouched.
        """

        def forget_gate_values(count):
            return torch.linspace(3.0, 6.0, count).to(device=bias.device, dtype=bias.dtype)

        with torch.no_grad():
            if self.config.weight_mode == "single":
                if "igate" in module_name:
                    bias.fill_(-10.0)
                elif "fgate" in module_name:
                    bias.copy_(forget_gate_values(bias.shape[-1]))
            elif "ifgate" in module_name:
                num_heads = self.config.num_heads
                bias[:num_heads].fill_(-10.0)
                bias[num_heads:].copy_(forget_gate_values(bias.shape[-1] - num_heads))

    def _init_weights(self, module):
        """Initialize the parameters of a single submodule.

        - `nn.Embedding`: small init on the module's own weight.
        - `nn.Linear`: bias zeroed; then, by registered name: gate projections get zero
          weights plus gate-specific biases, "proj_down" and "out_proj" get the Wang
          init, everything else the small init.
        - `xLSTMRMSNorm` / modules exposing `_layer_normalize`: weight ones, bias zeros.
        """
        if isinstance(module, nn.Embedding):
            # Use the module that was handed in rather than `self.embeddings`, which only
            # exists on models that own the embedding table directly.
            small_init_method(self.config.hidden_size)(module.weight)
        elif isinstance(module, nn.Linear):
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

            # Resolve the name once; every lookup walks all submodules.
            module_name = self._module_name_map(module)
            is_gate = self.config.weight_mode in ("single", "fused") and "gate" in module_name

            if is_gate:
                # NOTE(review): "gate" also matches the output gate and the feed-forward
                # gate projections, whose weights are therefore zeroed as well. This
                # mirrors the previous matching rule - confirm it is intended.
                torch.nn.init.zeros_(module.weight)
                if module.bias is not None:
                    self._init_gate_bias(module.bias, module_name)
            elif "proj_down" in module_name:
                wang_init_method(dim=module.weight.shape[1], n_layers=self.config.num_hidden_layers)(module.weight)
            elif "out_proj" in module_name:
                wang_init_method(dim=self.config.hidden_size, n_layers=self.config.num_hidden_layers)(module.weight)
            elif module.weight is not None:
                small_init_method(self.config.hidden_size)(module.weight)
        elif isinstance(module, xLSTMRMSNorm) or hasattr(module, "_layer_normalize"):
            # Norm layers may be built without a weight and/or bias.
            if getattr(module, "weight", None) is not None:
                torch.nn.init.ones_(module.weight)
            if getattr(module, "bias", None) is not None:
                torch.nn.init.zeros_(module.bias)
@auto_docstring
class xLSTMModel(xLSTMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # use embedding_dim and num_blocks once here to make use of them
        self.embeddings = nn.Embedding(config.vocab_size, config.embedding_dim)
        self.blocks = nn.ModuleList([xLSTMBlock(config) for _ in range(config.num_blocks)])
        self.out_norm = xLSTMRMSNorm(config.hidden_size, eps=config.norm_eps)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, new_embedding):
        self.embeddings = new_embedding

    @staticmethod
    def _store_rnn_state(cache_params, layer_idx, rnn_state):
        # Copy the state tensors returned by a block into the cache entry of that layer
        # (in place, so the cache keeps its own tensors) and mark the cache as used.
        for state_idx in range(len(cache_params.rnn_state[layer_idx])):
            cache_params.rnn_state[layer_idx][state_idx].copy_(rnn_state[state_idx])
        cache_params.rnn_state_initial = False

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_params: Optional[xLSTMCache] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple, xLSTMOutput]:
        r"""
        cache_params (`xLSTMCache`, *optional*):
            The xLSTMCache that carries the RNN states.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        if use_cache and cache_params is None:
            cache_params = xLSTMCache(
                self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
            )

        hidden_states = inputs_embeds
        # Bound on every path so the final return never hits an unbound local.
        all_hidden_states = () if output_hidden_states else None

        if (
            not self.training
            and self.config.max_inference_chunksize < hidden_states.shape[1]
            and not output_hidden_states
        ):
            # Long inputs at inference time: feed the sequence through all blocks in slices
            # of at most `max_inference_chunksize` tokens, carrying the state in the cache.
            chunk_size = self.config.max_inference_chunksize
            sequence_length = hidden_states.shape[1]
            with torch.no_grad():
                if cache_params is None:
                    # The state has to be carried between slices even when the caller did
                    # not ask for a cache, so one is created on the inputs' device/dtype.
                    cache_params = xLSTMCache(
                        self.config,
                        hidden_states.shape[0],
                        device=hidden_states.device,
                        dtype=hidden_states.dtype,
                    )
                final_state = torch.zeros_like(hidden_states)
                offset = 0
                while offset < sequence_length:
                    chunk_end = min(offset + chunk_size, sequence_length)
                    hidden_states_chunk = hidden_states[:, offset:chunk_end]
                    for layer_idx, xlstm_block in enumerate(self.blocks):
                        hidden_states_chunk, rnn_state = xlstm_block(
                            hidden_states_chunk,
                            state=cache_params.rnn_state[layer_idx],
                        )
                        self._store_rnn_state(cache_params, layer_idx, rnn_state)
                    final_state[:, offset:chunk_end] = hidden_states_chunk
                    offset += chunk_size
                hidden_states = final_state
        else:
            for layer_idx, xlstm_block in enumerate(self.blocks):
                layer_state = cache_params.rnn_state[layer_idx] if cache_params is not None else None
                if self.gradient_checkpointing and self.training:
                    hidden_states, rnn_state = self._gradient_checkpointing_func(
                        xlstm_block.__call__,
                        hidden_states,
                        layer_state,
                    )
                else:
                    hidden_states, rnn_state = xlstm_block(
                        hidden_states,
                        state=layer_state,
                    )
                if cache_params is not None:
                    self._store_rnn_state(cache_params, layer_idx, rnn_state)

                if output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states,)

        if use_cache:
            cache_params.seqlen_offset += inputs_embeds.shape[1]

        hidden_states = self.out_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return xLSTMOutput(
            last_hidden_state=hidden_states,
            cache_params=cache_params,
            hidden_states=all_hidden_states,
        )
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + cache_params: Optional[xLSTMCache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + + +@auto_docstring +class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin): + def __init__(self, config): + super().__init__(config) + self.backbone = xLSTMModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.backbone.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.backbone.set_input_embeddings(new_embeddings) + + def prepare_inputs_for_generation( + self, + input_ids, + inputs_embeds=None, + use_cache=None, + cache_params: Optional[xLSTMCache] = None, + **kwargs, + ): + if use_cache and cache_params is not None: + # If the first cache position is non-zero, we assume we are in generation mode. + # Thus, the cache_params state is assumed to be the state before the last token + # (lastly generated token), and all previous tokens are already ingested. + # This should as well support generation from scratch with the [BOS] token inserted first. 
+ input_ids = input_ids[:, -1:] + if inputs_embeds is not None: + inputs_embeds = inputs_embeds[:, -1:] + + if inputs_embeds is not None and cache_params is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update({"cache_params": cache_params, "use_cache": use_cache}) + return model_inputs + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_params: Optional[xLSTMCache] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ) -> Union[tuple, xLSTMCausalLMOutput]: + r""" + cache_params (`xLSTMCache`, *optional*): + The xLSTMCache that carries the RNN states. + """ + xlstm_outputs = self.backbone( + input_ids, + cache_params=cache_params, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + **kwargs, + ) + hidden_states = xlstm_outputs[0] + + logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + + if not self.training and self.config.max_inference_chunksize < logits.shape[1]: + offset = 0 + with torch.no_grad(): + while offset < logits.shape[1]: + logits[:, offset : min(offset + self.config.max_inference_chunksize, logits.shape[1])] = soft_cap( + logits[:, offset : min(offset + self.config.max_inference_chunksize, logits.shape[1])], + self.config.output_logit_soft_cap, + ) + offset += self.config.max_inference_chunksize + else: + logits = soft_cap(logits, self.config.output_logit_soft_cap) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < nstate predict nstate + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = 
CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + return xLSTMCausalLMOutput( + loss=loss, + logits=logits, + cache_params=xlstm_outputs.cache_params, + hidden_states=xlstm_outputs.hidden_states, + ) + + +__all__ = [ + "xLSTMForCausalLM", + "xLSTMModel", + "xLSTMPreTrainedModel", +] diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 0bb3709a42..c45a93e406 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -270,6 +270,7 @@ from .import_utils import ( is_uroman_available, is_vision_available, is_vptq_available, + is_xlstm_available, is_yt_dlp_available, requires_backends, torch_only_method, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 310b00eb73..fb9d56e160 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -588,6 +588,12 @@ def is_causal_conv1d_available(): return False +def is_xlstm_available(): + if is_torch_available(): + return _is_package_available("xlstm") + return False + + def is_mambapy_available(): if is_torch_available(): return _is_package_available("mambapy") diff --git a/tests/models/xlstm/__init__.py b/tests/models/xlstm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/xlstm/test_modeling_xlstm.py b/tests/models/xlstm/test_modeling_xlstm.py new file mode 100644 index 0000000000..3ad5f67100 --- /dev/null +++ b/tests/models/xlstm/test_modeling_xlstm.py @@ -0,0 +1,371 @@ +# Copyright 2025 NXAI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +from parameterized import parameterized + +from transformers import AutoTokenizer, is_torch_available, xLSTMConfig +from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + xLSTMForCausalLM, + xLSTMModel, + ) + from transformers.models.xlstm.modeling_xlstm import xLSTMBlock, xLSTMCache + from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2 +else: + is_torch_greater_or_equal_than_2_2 = False + + +class xLSTMModelTester: + def __init__( + self, + parent, + batch_size=13, + num_heads=2, + seq_length=7, + is_training=True, + use_labels=True, + vocab_size=99, + hidden_size=128, + qk_dim_factor=0.5, + v_dim_factor=1.0, + num_hidden_layers=2, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + num_labels=3, + num_choices=4, + scope=None, + chunkwise_kernel="chunkwise--native_autograd", + sequence_kernel="native_sequence__native", + step_kernel="native", + tie_word_embeddings=False, + ): + self.parent = parent + self.num_heads = num_heads + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.num_hidden_layers = 
num_hidden_layers + self.hidden_size = hidden_size + self.qk_dim_factor = qk_dim_factor + self.v_dim_factor = v_dim_factor + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + self.bos_token_id = vocab_size - 1 + self.eos_token_id = vocab_size - 1 + self.pad_token_id = vocab_size - 1 + self.chunkwise_kernel = chunkwise_kernel + self.sequence_kernel = sequence_kernel + self.step_kernel = step_kernel + self.tie_word_embeddings = tie_word_embeddings + + def get_large_model_config(self): + cfg = xLSTMConfig.from_pretrained("NX-AI/xLSTM-7b") + return cfg + + def prepare_config_and_inputs(self, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return ( + config, + input_ids, + None, + sequence_labels, + token_labels, + choice_labels, + ) + + def get_config(self): + cfg = xLSTMConfig( + num_heads=self.num_heads, + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + qk_dim_factor=self.qk_dim_factor, + v_dim_factor=self.v_dim_factor, + n_positions=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + use_cache=True, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + chunkwise_kernel=self.chunkwise_kernel, + sequence_kernel=self.sequence_kernel, + step_kernel=self.step_kernel, + 
tie_word_embeddings=self.tie_word_embeddings, + ) + # this is needed for compatibility with generic tests + # cfg.hidden_size = cfg.embedding_dim + # cfg.num_hidden_layers = cfg.num_blocks + return cfg + + def prepare_config_and_inputs_for_common(self): + ( + config, + input_ids, + _, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + inputs_dict = {"input_ids": input_ids} + return config, inputs_dict + + +@require_torch +class xLSTMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (xLSTMModel, xLSTMForCausalLM) if is_torch_available() else () + all_generative_model_classes = (xLSTMForCausalLM,) if is_torch_available() else () + has_attentions = False # xLSTM does not support attentions + fx_compatible = False + test_torchscript = False + test_model_parallel = False + test_pruning = False + test_head_masking = False # xLSTM does not have attention heads + + pipeline_model_mapping = ( + {"feature-extraction": xLSTMModel, "text-generation": xLSTMForCausalLM} if is_torch_available() else {} + ) + + def setUp(self): + self.model_tester = xLSTMModelTester(self) + self.config_tester = ConfigTester( + self, config_class=xLSTMConfig, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"] + ) + + def test_initialization(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config=config) + for name, param in model.named_parameters(): + if "D" in name: + if param.requires_grad: + # check if it's a ones like + self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5)) + + @unittest.skip(reason="xLSTM has no tied weights") + def test_tied_weights_keys(self): + pass + + @unittest.skip(reason="xLSTM cache slicing test case is an edge case") + def test_generate_without_input_ids(self): + pass + + @unittest.skip(reason="xLSTM cache slicing 
test case is an edge case") + @parameterized.expand([("greedy", 1), ("beam search", 2)]) + def test_generate_from_inputs_embeds(self, _, num_beams): + pass + + @unittest.skip(reason="xLSTM cache slicing test case is an edge case") + def test_greedy_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip(reason="xLSTM cache slicing is interacting with beam search") + def test_beam_search_generate_dict_outputs_use_cache(self): + pass + + def test_model_outputs_equivalence(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + with torch.no_grad(): + tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) + dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, xLSTMCache): + recursive_check(tuple_object.rnn_state, dict_object.rnn_state) + elif isinstance(tuple_object, (list, tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, dict): + for tuple_iterable_value, dict_iterable_value in zip( + tuple_object.values(), dict_object.values() + ): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + torch.allclose(tuple_object, dict_object, atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has" + f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}." 
+ ), + ) + + recursive_check(tuple_output, dict_output) + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + + +@require_torch +@slow +@require_read_token +class xLSTMIntegrationTest(unittest.TestCase): + def setUp(self): + self.model_id = "NX-AI/xLSTM-7b" + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_slow=True, legacy=False) + self.prompt = ("[INST]Write a hello world program in C++.",) + + def test_simple_generate(self, device=torch_device):  # unittest calls test methods without arguments + """ + Simple generate test to avoid regressions. + Note: state-spaces (cuda) implementation and pure torch implementation + have irreconcilable differences as of now, which will cause this test to fail + in an environment with state-spaces installed. 
+ """ + tokenizer = self.tokenizer + tokenizer.pad_token_id = tokenizer.eos_token_id + + model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16) + model.to(device) + input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to( + device + ) + + out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30) + output_sentence = tokenizer.decode(out[0]) + ground_truth_sentence = """[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program in C++:\n\n```cpp\n#include \n\n""" + self.assertEqual(output_sentence, ground_truth_sentence) + + def test_batched_equivalence_with_cache(self): + """ + Verifies that batched generation matches individual generation. + Important because of the specific caching mechanism + statefulness of the xLSTM model. + Depending on precision and devices, differences can be observed from generation to generation. + """ + tokenizer = self.tokenizer + prompt = [ + "[INST]Write C#.[/INST]", + "[INST]Write a hello world in C++.[/INST]", + "[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]", + ] + + model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device) + tokenizer.pad_token_id = tokenizer.eos_token_id + # batched generation + tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device) + batched_gen = model.generate(**tokenized_prompts, max_new_tokens=30, use_cache=True) + batched_output = tokenizer.batch_decode(batched_gen, skip_special_tokens=True) + + # individual generation + + for index_gen, individual_prompt in enumerate(prompt): + inputs = tokenizer(individual_prompt, return_tensors="pt", padding="longest").to(torch_device) + individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True) + individual_output = 
tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0] + self.assertEqual(individual_output[:100], batched_output[index_gen][:100]) + + def test_batched_equivalence_without_cache(self): + """ + Verifies that batched generation matches individual generation without cache. + Important because of the specific caching mechanism + statefulness of the xLSTM model. + Depending on precision and devices, differences can be observed from generation to generation. + """ + tokenizer = self.tokenizer + prompt = [ + "[INST]Write C#.[/INST]", + "[INST]Write a hello world in C++.[/INST]", + "[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]", + ] + + model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device) + tokenizer.pad_token_id = tokenizer.eos_token_id + # batched generation + tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device) + batched_gen = model.generate(**tokenized_prompts, max_new_tokens=30, use_cache=False)  # cache disabled: that is what this test covers + batched_output = tokenizer.batch_decode(batched_gen, skip_special_tokens=True) + + # individual generation + + for index_gen, individual_prompt in enumerate(prompt): + inputs = tokenizer(individual_prompt, return_tensors="pt", padding="longest").to(torch_device) + individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=False) + individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0] + self.assertEqual(individual_output[:100], batched_output[index_gen][:100]) + + @require_torch_gpu + def test_xlstm_block_train_vs_eval_equivalence(self): + # Based on https://github.com/sustcsonglin/flash-linear-attention/issues/63 + # Credit to zhixuan-lin + + B, T, D = 4, 512, 768 + dtype = torch.bfloat16 + config = xLSTMConfig(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1) + + torch.manual_seed(42) + with torch.amp.autocast(device_type="cuda", 
dtype=dtype): + with torch.no_grad(): + block = xLSTMBlock(config.to_xlstm_block_config(), layer_idx=0).to("cuda") + hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda") + + block.train() + out_train, _ = block(hidden_states)  # the block returns (hidden_states, rnn_state); compare hidden states only + + block.eval() + out_eval, _ = block(hidden_states) + + self.assertTrue(torch.allclose(out_train, out_eval, atol=1e-3)) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index d8bd32847e..c358e6a393 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -32,6 +32,7 @@ transformers = direct_transformers_import(PATH_TO_TRANSFORMERS) CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING SPECIAL_CASES_TO_ALLOW = { + "xLSTMConfig": ["add_out_norm", "chunkwise_kernel", "sequence_kernel", "step_kernel"], "Ernie4_5Config": ["tie_word_embeddings"], "Ernie4_5_MoeConfig": ["tie_word_embeddings"], "Lfm2Config": ["full_attn_idxs", "tie_word_embeddings"],