Add xlstm model (#39665)
* Add xLSTM cleanly with optimizations. * Fix style. * Fix modeling test. * Make xLSTM package optional. * Fix: Update torch version check. * Fix: Bad variable naming in test. * Fix: Import structure cleaning with Ruff. * Fix: Update docstrings. * Fix: Mitigate unused config attr tests by explicit usage. * Fix: Skip tests, if xlstm library is not installed. * Feat: Enable longer context window for inference by chunking. * Fix: Make training test pass by lowering target accuracy. * Chore: Increase test verbosity for failing generation test. * Update docs/source/en/model_doc/xlstm.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Fix: Make xlstm available even without CUDA. * Chore: Remove unnecessary import. * Fix: Remove BOS insertion. * Chore: Improve xLSTMCache documentation. * Integrate basic xLSTM fallback code. * Chore: Remove unnecessary import. * Chore: Remove duplicate LayerNorm. * chore: update copyright, minor reformatting * fix: refactor mLSTMStateType due to missing torch import * fix: add missing import * Chore: Replace einops. * fix: apply ruff formatting * fix: run `make fix-copies` to re-generate dummy_pt_objects.py * fix: make type hints Python 3.9 compatible * fix: remove obsolete import * fix: remove obsolete method from docs * chore: remove obsolete `force_bos_token_insert` from config * Chore: Remove duplicated xLSTMCache class. * Fix: Formatting of modeling_xlstm.py * Chore: Remove xlstm package requirement from test. Re-add update_rnn_state. * Fix: Update xLSTMCache docstring. * Feat: Add proper initialization of xLSTM. * Chore: Re-format files. * Chore: Adapt format. * Fix: xLSTMCache import restructuring. * Fix: Add __all__ lists to modeling and configuration files. * Chore: Reformat. * Fix: Remove unnecessary update_rnn_state function. * Fix: Undo test accuracy quickfix. * Fix: Update copyright year, remvoe config copy. * Chore: Flatten all internal configs to xLSTMConfig. * Fix: Unused config variables check. * Chore: Remove unnecessary imports. * Fix: Unify xlstm cache argument from batch_size to max_batch_size. * Chore: Remove bad default arg value for xLSTMCache. * Chore: Rename core configuration arguments to HF default in xLSTM. * Chore: Fix formatting. * Fix: xLSTM Cache config access. * Fix: Update xlstm tests for config update. * Feat: Re-add embbeding_dim, num_blocks config options for compat with xLSTM-7B. * Fix: Configuration xLSTM python3.9 syntax. * Fix: Difference to main in test_utils.py assertion. * Fix: Bad syntax in xlstm config for python3.9. * Fix: xLSTMConfig docstring. * Fix: xLSTMConfig docstring. * Fix typing issues in xLSTM and BeiT, Paligemma. * Fix: Exclude xLSTM from test cache utils. * Chore: Fix style. * Chore: Fix format. * Chore: Remove unnecessary LayerNorm, NormLayer layer abstractions. * Chore: Remove asserts and replace with ValueErrors. * Chore: Update __init__.py structure of xLSTM. * Chore: Clean xLSTM initialization of weights. * Fix index names in modeling_xlstm.py * Update xlstm model test typing annotations. * Fix: Remove all asserts. * Revert changes to the main __init__.py * Fix: Move xLSTMCache to modeling_xlstm.py * Fix: Remove xLSTMForCausalLM mapping from modeling_auto.py * Remove xLSTMCache from dummy_pt_objects.py * Fix: Remove extended torchdynamo compilation check integrating cuda graph captures. * Revert test_cache_utils.py xLSTM change. * Fix: Move xLSTM init functions before init call. * Remove xLSTMCache from generation utils. * Fix: Clean xLSTM init functionality for recursive calls. * Fix: Move xLSTMCache before its first call. * Fix formatting. * Add partial docstring for xLSTMModel forward. * Fix xLSTMCache docstring in xLSTMModel. * Remove xLSTMCache from public documentation. Update auto_docstring. * Remove all agressive shape comments * style * Fix names * simplify * remove output_hidden_states * Update modeling_xlstm.py * Update modeling_xlstm.py * Update test_modeling_xlstm.py * Update modeling_xlstm.py * Update modeling_xlstm.py * fix * fix * style * style --------- Co-authored-by: Korbinian Poeppel <korbinian.poeppel@nx-ai.com> Co-authored-by: Korbinian Pöppel <37810656+kpoeppel@users.noreply.github.com> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Sebastian Böck <sebastian.boeck@nx-ai.com> Co-authored-by: Korbinian Poeppel <poeppel@ml.jku.at>
This commit is contained in:
@@ -697,6 +697,8 @@
|
|||||||
title: XLM-V
|
title: XLM-V
|
||||||
- local: model_doc/xlnet
|
- local: model_doc/xlnet
|
||||||
title: XLNet
|
title: XLNet
|
||||||
|
- local: model_doc/xlstm
|
||||||
|
title: xLSTM
|
||||||
- local: model_doc/yoso
|
- local: model_doc/yoso
|
||||||
title: YOSO
|
title: YOSO
|
||||||
- local: model_doc/zamba
|
- local: model_doc/zamba
|
||||||
|
|||||||
47
docs/source/en/model_doc/xlstm.md
Normal file
47
docs/source/en/model_doc/xlstm.md
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
<!--Copyright 2025 NXAI GmbH. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
|
||||||
|
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||||
|
rendered properly in your Markdown viewer.
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
|
||||||
|
# xLSTM
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The xLSTM model was proposed in [xLSTM: Extended Long Short-Term Memory](https://openreview.net/forum?id=ARAxPPIAhq) by Maximilian Beck*, Korbinian Pöppel*, Markus Spanring, Andreas Auer, Oleksandra Prudnikova, Michael Kopp, Günter Klambauer, Johannes Brandstetter and Sepp Hochreiter.
|
||||||
|
xLSTM updates the original LSTM architecture to be competitive with Transformer models by introducing exponential gating, matrix memory expansion, and parallelizable training and ingestion.
|
||||||
|
|
||||||
|
The [7B model](https://hf.co/NX-AI/xLSTM-7b) variant was trained by the xLSTM team Maximilian Beck, Korbinian Pöppel, Phillip Lippe, Richard Kurle, Patrick Blies, Sebastian Böck and Sepp Hochreiter at NXAI.
|
||||||
|
|
||||||
|
The abstract from the paper is the following:
|
||||||
|
|
||||||
|
*In the 1990s, the constant error carousel and gating were introduced as the central ideas of the Long Short-Term Memory (LSTM). Since then, LSTMs have stood the test of time and contributed to numerous deep learning success stories, in particular they constituted the first Large Language Models (LLMs). However, the advent of the Transformer technology with parallelizable self-attention at its core marked the dawn of a new era, outpacing LSTMs at scale. We now raise a simple question: How far do we get in language modeling when scaling LSTMs to billions of parameters, leveraging the latest techniques from modern LLMs, but mitigating known limitations of LSTMs? Firstly, we introduce exponential gating with appropriate normalization and stabilization techniques. Secondly, we modify the LSTM memory structure, obtaining: (i) sLSTM with a scalar memory, a scalar update, and new memory mixing, (ii) mLSTM that is fully parallelizable with a matrix memory and a covariance update rule. Integrating these LSTM extensions into residual block backbones yields xLSTM blocks that are then residually stacked into xLSTM architectures. Exponential gating and modified memory structures boost xLSTM capabilities to perform favorably when compared to state-of-the-art Transformers and State Space Models, both in performance and scaling.*
|
||||||
|
|
||||||
|
This model was contributed by [NX-AI](https://huggingface.co/NX-AI).
|
||||||
|
The original code can be found [here](https://github.com/NX-AI/xlstm).
|
||||||
|
|
||||||
|
|
||||||
|
## xLSTMConfig
|
||||||
|
|
||||||
|
[[autodoc]] xLSTMConfig
|
||||||
|
|
||||||
|
## xLSTMModel
|
||||||
|
|
||||||
|
[[autodoc]] xLSTMModel
|
||||||
|
- forward
|
||||||
|
|
||||||
|
## xLSTMLMHeadModel
|
||||||
|
|
||||||
|
[[autodoc]] xLSTMForCausalLM
|
||||||
|
- forward
|
||||||
@@ -353,6 +353,7 @@ if TYPE_CHECKING:
|
|||||||
from .xlm_roberta import *
|
from .xlm_roberta import *
|
||||||
from .xlm_roberta_xl import *
|
from .xlm_roberta_xl import *
|
||||||
from .xlnet import *
|
from .xlnet import *
|
||||||
|
from .xlstm import *
|
||||||
from .xmod import *
|
from .xmod import *
|
||||||
from .yolos import *
|
from .yolos import *
|
||||||
from .yoso import *
|
from .yoso import *
|
||||||
|
|||||||
@@ -410,6 +410,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
|||||||
("xlm-roberta", "XLMRobertaConfig"),
|
("xlm-roberta", "XLMRobertaConfig"),
|
||||||
("xlm-roberta-xl", "XLMRobertaXLConfig"),
|
("xlm-roberta-xl", "XLMRobertaXLConfig"),
|
||||||
("xlnet", "XLNetConfig"),
|
("xlnet", "XLNetConfig"),
|
||||||
|
("xlstm", "xLSTMConfig"),
|
||||||
("xmod", "XmodConfig"),
|
("xmod", "XmodConfig"),
|
||||||
("yolos", "YolosConfig"),
|
("yolos", "YolosConfig"),
|
||||||
("yoso", "YosoConfig"),
|
("yoso", "YosoConfig"),
|
||||||
@@ -832,6 +833,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
|||||||
("xlnet", "XLNet"),
|
("xlnet", "XLNet"),
|
||||||
("xls_r", "XLS-R"),
|
("xls_r", "XLS-R"),
|
||||||
("xlsr_wav2vec2", "XLSR-Wav2Vec2"),
|
("xlsr_wav2vec2", "XLSR-Wav2Vec2"),
|
||||||
|
("xlstm", "xLSTM"),
|
||||||
("xmod", "X-MOD"),
|
("xmod", "X-MOD"),
|
||||||
("yolos", "YOLOS"),
|
("yolos", "YOLOS"),
|
||||||
("yoso", "YOSO"),
|
("yoso", "YOSO"),
|
||||||
|
|||||||
@@ -379,6 +379,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("xlm-roberta", "XLMRobertaModel"),
|
("xlm-roberta", "XLMRobertaModel"),
|
||||||
("xlm-roberta-xl", "XLMRobertaXLModel"),
|
("xlm-roberta-xl", "XLMRobertaXLModel"),
|
||||||
("xlnet", "XLNetModel"),
|
("xlnet", "XLNetModel"),
|
||||||
|
("xlstm", "xLSTMModel"),
|
||||||
("xmod", "XmodModel"),
|
("xmod", "XmodModel"),
|
||||||
("yolos", "YolosModel"),
|
("yolos", "YolosModel"),
|
||||||
("yoso", "YosoModel"),
|
("yoso", "YosoModel"),
|
||||||
@@ -474,6 +475,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
|
|||||||
("xlm-roberta", "XLMRobertaForMaskedLM"),
|
("xlm-roberta", "XLMRobertaForMaskedLM"),
|
||||||
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
|
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
|
||||||
("xlnet", "XLNetLMHeadModel"),
|
("xlnet", "XLNetLMHeadModel"),
|
||||||
|
("xlstm", "xLSTMForCausalLM"),
|
||||||
("xmod", "XmodForMaskedLM"),
|
("xmod", "XmodForMaskedLM"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -692,6 +694,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
|||||||
("xlm-roberta", "XLMRobertaForCausalLM"),
|
("xlm-roberta", "XLMRobertaForCausalLM"),
|
||||||
("xlm-roberta-xl", "XLMRobertaXLForCausalLM"),
|
("xlm-roberta-xl", "XLMRobertaXLForCausalLM"),
|
||||||
("xlnet", "XLNetLMHeadModel"),
|
("xlnet", "XLNetLMHeadModel"),
|
||||||
|
("xlstm", "xLSTMForCausalLM"),
|
||||||
("xmod", "XmodForCausalLM"),
|
("xmod", "XmodForCausalLM"),
|
||||||
("zamba", "ZambaForCausalLM"),
|
("zamba", "ZambaForCausalLM"),
|
||||||
("zamba2", "Zamba2ForCausalLM"),
|
("zamba2", "Zamba2ForCausalLM"),
|
||||||
|
|||||||
@@ -718,6 +718,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
|||||||
"XLNetTokenizerFast" if is_tokenizers_available() else None,
|
"XLNetTokenizerFast" if is_tokenizers_available() else None,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
("xlstm", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
(
|
(
|
||||||
"xmod",
|
"xmod",
|
||||||
(
|
(
|
||||||
|
|||||||
31
src/transformers/models/xlstm/__init__.py
Normal file
31
src/transformers/models/xlstm/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
# Copyright 2025 NXAI GmbH. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...utils import (
|
||||||
|
OptionalDependencyNotAvailable,
|
||||||
|
_LazyModule,
|
||||||
|
is_torch_available,
|
||||||
|
)
|
||||||
|
from ...utils.import_utils import define_import_structure
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from configuration_xlstm import *
|
||||||
|
from modeling_xlstm import *
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
_file = globals()["__file__"]
|
||||||
|
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
||||||
302
src/transformers/models/xlstm/configuration_xlstm.py
Normal file
302
src/transformers/models/xlstm/configuration_xlstm.py
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
# Copyright 2025 NXAI GmbH. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""xLSTM configuration."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...utils import is_xlstm_available, logging
|
||||||
|
|
||||||
|
|
||||||
|
if is_xlstm_available():
|
||||||
|
from xlstm.xlstm_large.model import (
|
||||||
|
BackendModeType,
|
||||||
|
ChunkwiseKernelType,
|
||||||
|
DtypeType,
|
||||||
|
SequenceKernelType,
|
||||||
|
StepKernelType,
|
||||||
|
WeightModeType,
|
||||||
|
round_up_to_next_multiple_of,
|
||||||
|
xLSTMLargeConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
external_xlstm = True
|
||||||
|
else:
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
BackendModeType = Literal["train", "train_with_padding", "inference"]
|
||||||
|
ChunkwiseKernelType = Literal[
|
||||||
|
"chunkwise--native_autograd",
|
||||||
|
"parallel--native_autograd",
|
||||||
|
]
|
||||||
|
DtypeType = Literal["float32", "bfloat16", "float16"]
|
||||||
|
SequenceKernelType = Literal["native_sequence__native"]
|
||||||
|
StepKernelType = Literal["native"]
|
||||||
|
WeightModeType = Literal["single", "fused"]
|
||||||
|
|
||||||
|
def round_up_to_next_multiple_of(x: int, multiple_of: int) -> int:
|
||||||
|
"""Rounds up x to the next multiple of multiple_of."""
|
||||||
|
return int(((x + multiple_of - 1) // multiple_of) * multiple_of)
|
||||||
|
|
||||||
|
external_xlstm = False
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class xLSTMConfig(PretrainedConfig):
|
||||||
|
"""
|
||||||
|
This is the configuration class to store the configuration of a [`xLSTM`]. It is used to instantiate a xLSTM
|
||||||
|
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||||
|
defaults will yield a similar configuration to that of the xLSTM-7b [NX-AI/xLSTM-7b](https://huggingface.co/NX-AI/xLSTM-7b) model.
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size (int, optional, *optional*, defaults to 50304):
|
||||||
|
Vocabulary size of the xLSTM model. Defines the number of different tokens that can be represented by the
|
||||||
|
`inputs_ids` passed when calling [`xLSTMModel`]. Defaults to the GPT2-NeoX tokenizer size.
|
||||||
|
hidden_size (int, optional, *optional*, defaults to 4096):
|
||||||
|
Dimensionality of the embeddings or hidden states.
|
||||||
|
embedding_dim (int, optional, *optional*, defaults to 4096):
|
||||||
|
Dimensionality of the embeddings or hidden states, use hidde_size if None.
|
||||||
|
num_hidden_layers (int, optional, *optional*, defaults to 32):
|
||||||
|
Number of blocks of the xLSTM model.
|
||||||
|
num_blocks (int, optional, *optional*, defaults to 32):
|
||||||
|
Number of blocks of the xLSTM model, use num_hidden_layers if None.
|
||||||
|
num_heads (int, optional, *optional*, defaults to 8):
|
||||||
|
Number of heads for the xLSTM Layer/Cell.
|
||||||
|
use_bias (bool, optional, *optional*, defaults to `False`):
|
||||||
|
Whether to use biases in the xLSTM model.
|
||||||
|
norm_reduction_force_float32 (bool, optional, *optional*, defaults to `True`):
|
||||||
|
Whether to force the float32 norm reduction op to be done in fp32 precision.
|
||||||
|
tie_word_embeddings (bool, optional, *optional*, defaults to `False`):
|
||||||
|
Whether to tie word embeddings to the lm head weights.
|
||||||
|
add_out_norm (bool, optional, *optional*, defaults to `True`):
|
||||||
|
Whether to add an output norm after the blocks before the LMHead.
|
||||||
|
norm_eps (float, optional, *optional*, defaults to 1e-06):
|
||||||
|
Norm eps for RMSNorm and Layer Norm.
|
||||||
|
qk_dim_factor (float, optional, *optional*, defaults to 0.5):
|
||||||
|
Scale factor for the query and key dimension.
|
||||||
|
v_dim_factor (float, optional, *optional*, defaults to 1.0):
|
||||||
|
Scale factor for the value dimension.
|
||||||
|
chunkwise_kernel (ChunkwiseKernelType, optional, *optional*, defaults to `"chunkwise--native_autograd"`):
|
||||||
|
Kernel type for chunkwise processing mode.
|
||||||
|
sequence_kernel (SequenceKernelType, optional, *optional*, defaults to `"native_sequence__native"`):
|
||||||
|
Kernel type for sequence processing mode.
|
||||||
|
step_kernel (StepKernelType, optional, *optional*, defaults to `"native"`):
|
||||||
|
Kernel type for step processing mode.
|
||||||
|
mode (BackendModeType, optional, *optional*, defaults to `"inference"`):
|
||||||
|
Operation mode (inference is needed for generation).
|
||||||
|
chunk_size (int, optional, *optional*, defaults to 64):
|
||||||
|
Internal chunk size.
|
||||||
|
return_last_states (bool, optional, *optional*, defaults to `True`):
|
||||||
|
If to return the last states / cache internally. Needed as True for generation.
|
||||||
|
autocast_kernel_dtype (DtypeType, optional, *optional*, defaults to `"bfloat16"`):
|
||||||
|
Kernel dtype for the states.
|
||||||
|
eps (float, optional, *optional*, defaults to 1e-06):
|
||||||
|
Epsilon for the mLSTM cell post norm.
|
||||||
|
inference_state_dtype (DtypeType, optional, *optional*, defaults to `"float32"`):
|
||||||
|
Kernel dtype for states in inference.
|
||||||
|
ffn_proj_factor (float, optional, *optional*, defaults to 2.667):
|
||||||
|
Size factor of the post-up projection gated Feed Forward network.
|
||||||
|
ffn_round_up_to_multiple_of (int, optional, *optional*, defaults to 64):
|
||||||
|
Size factor round value of the post-up projection gated Feed Forward network.
|
||||||
|
gate_soft_cap (float, optional, *optional*, defaults to 15.0):
|
||||||
|
Gate soft cap scale.
|
||||||
|
output_logit_soft_cap (float, optional, *optional*, defaults to 30.0):
|
||||||
|
Output logit soft cap scale.
|
||||||
|
weight_mode (`Literal`, *optional*, defaults to `"single"`):
|
||||||
|
Whether parallel linear layers are separated or fused (single).
|
||||||
|
use_cache (bool, optional, *optional*, defaults to `True`):
|
||||||
|
Whether to use the cache (xLSTMCache).
|
||||||
|
pad_token_id (int, optional, *optional*, defaults to 1):
|
||||||
|
Pad token id needed for generation.
|
||||||
|
bos_token_id (int, optional, *optional*, defaults to 0):
|
||||||
|
BOS token id needed for generation.
|
||||||
|
eos_token_id (int, optional, *optional*, defaults to 2):
|
||||||
|
EOS token id needed for generation.
|
||||||
|
max_inference_chunksize (int, optional, *optional*, defaults to 16384):
|
||||||
|
Limit the chunk size for inference to save memory.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import xLSTMConfig, xLSTMModel
|
||||||
|
|
||||||
|
>>> # Initializing a xLSTM configuration
|
||||||
|
>>> configuration = xLSTMConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a model (with random weights) from the configuration
|
||||||
|
>>> model = xLSTMModel(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "xlstm"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size: int = 50304,
|
||||||
|
hidden_size: int = 4096,
|
||||||
|
embedding_dim: Optional[int] = None,
|
||||||
|
num_hidden_layers: Optional[int] = 32,
|
||||||
|
num_blocks: Optional[int] = None,
|
||||||
|
num_heads: int = 8,
|
||||||
|
use_bias: bool = False,
|
||||||
|
norm_reduction_force_float32: bool = True,
|
||||||
|
tie_word_embeddings: bool = False,
|
||||||
|
add_out_norm: bool = True,
|
||||||
|
norm_eps: float = 1e-6,
|
||||||
|
# mlstm_layer
|
||||||
|
qk_dim_factor: float = 0.5,
|
||||||
|
v_dim_factor: float = 1.0,
|
||||||
|
# mlstm backend
|
||||||
|
chunkwise_kernel: ChunkwiseKernelType = "chunkwise--native_autograd",
|
||||||
|
sequence_kernel: SequenceKernelType = "native_sequence__native",
|
||||||
|
step_kernel: StepKernelType = "native",
|
||||||
|
# nedded to enable generation
|
||||||
|
mode: BackendModeType = "inference",
|
||||||
|
chunk_size: int = 64,
|
||||||
|
# needed to be true for generation
|
||||||
|
return_last_states: bool = True,
|
||||||
|
autocast_kernel_dtype: DtypeType = "bfloat16",
|
||||||
|
eps: float = 1e-6,
|
||||||
|
inference_state_dtype: DtypeType = "float32",
|
||||||
|
# feedforward
|
||||||
|
ffn_proj_factor: float = 2.667,
|
||||||
|
ffn_round_up_to_multiple_of: int = 64,
|
||||||
|
# capping
|
||||||
|
gate_soft_cap: float = 15.0,
|
||||||
|
output_logit_soft_cap: float = 30.0,
|
||||||
|
# weights
|
||||||
|
weight_mode: WeightModeType = "single",
|
||||||
|
# HF interface
|
||||||
|
use_cache: bool = True,
|
||||||
|
pad_token_id: int = 1,
|
||||||
|
bos_token_id: int = 0,
|
||||||
|
eos_token_id: int = 2,
|
||||||
|
max_inference_chunksize: int = 16384,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.hidden_size = hidden_size if hidden_size is not None else embedding_dim
|
||||||
|
self.embedding_dim = embedding_dim if embedding_dim is not None else hidden_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers if num_hidden_layers is not None else num_blocks
|
||||||
|
self.num_blocks = num_blocks if num_blocks is not None else num_hidden_layers
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.use_bias = use_bias
|
||||||
|
self.tie_word_embeddings = tie_word_embeddings
|
||||||
|
self.add_out_norm = add_out_norm
|
||||||
|
self.norm_eps = norm_eps
|
||||||
|
self.norm_reduction_force_float32 = norm_reduction_force_float32
|
||||||
|
# mlstm_layer
|
||||||
|
self.qk_dim_factor = qk_dim_factor
|
||||||
|
self.v_dim_factor = v_dim_factor
|
||||||
|
# mlstm backend
|
||||||
|
self.chunkwise_kernel = chunkwise_kernel
|
||||||
|
self.sequence_kernel = sequence_kernel
|
||||||
|
self.step_kernel = step_kernel
|
||||||
|
self.mode = mode
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self.return_last_states = return_last_states
|
||||||
|
self.autocast_kernel_dtype = autocast_kernel_dtype
|
||||||
|
self.eps = eps
|
||||||
|
self.inference_state_dtype = inference_state_dtype
|
||||||
|
# feedforward
|
||||||
|
self.ffn_proj_factor = ffn_proj_factor
|
||||||
|
self.ffn_round_up_to_multiple_of = ffn_round_up_to_multiple_of
|
||||||
|
# capping
|
||||||
|
self.gate_soft_cap = gate_soft_cap
|
||||||
|
self.output_logit_soft_cap = output_logit_soft_cap
|
||||||
|
self.weight_mode = weight_mode
|
||||||
|
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.max_inference_chunksize = max_inference_chunksize
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def qk_dim(self):
|
||||||
|
return round_up_to_next_multiple_of(
|
||||||
|
self.hidden_size * self.qk_dim_factor,
|
||||||
|
multiple_of=64,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def v_dim(self):
|
||||||
|
return round_up_to_next_multiple_of(
|
||||||
|
self.hidden_size * self.v_dim_factor,
|
||||||
|
multiple_of=64,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def qk_head_dim(self):
|
||||||
|
return self.qk_dim // self.num_heads
|
||||||
|
|
||||||
|
@property
|
||||||
|
def v_head_dim(self):
|
||||||
|
return self.v_dim // self.num_heads
|
||||||
|
|
||||||
|
def to_xlstm_block_config(self):
|
||||||
|
if external_xlstm:
|
||||||
|
return xLSTMLargeConfig(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
embedding_dim=self.hidden_size,
|
||||||
|
num_blocks=self.num_hidden_layers,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
use_bias=self.use_bias,
|
||||||
|
add_out_norm=self.add_out_norm,
|
||||||
|
norm_eps=self.norm_eps,
|
||||||
|
norm_reduction_force_float32=self.norm_reduction_force_float32,
|
||||||
|
# mlstm_layer
|
||||||
|
qk_dim_factor=self.qk_dim_factor,
|
||||||
|
v_dim_factor=self.v_dim_factor,
|
||||||
|
# mlstm backend
|
||||||
|
chunkwise_kernel=self.chunkwise_kernel,
|
||||||
|
sequence_kernel=self.sequence_kernel,
|
||||||
|
step_kernel=self.step_kernel,
|
||||||
|
mode=self.mode,
|
||||||
|
chunk_size=self.chunk_size,
|
||||||
|
return_last_states=self.return_last_states,
|
||||||
|
autocast_kernel_dtype=self.autocast_kernel_dtype,
|
||||||
|
eps=self.eps,
|
||||||
|
inference_state_dtype=self.inference_state_dtype,
|
||||||
|
# feedforward
|
||||||
|
ffn_proj_factor=self.ffn_proj_factor,
|
||||||
|
ffn_round_up_to_multiple_of=self.ffn_round_up_to_multiple_of,
|
||||||
|
# capping
|
||||||
|
gate_soft_cap=self.gate_soft_cap,
|
||||||
|
output_logit_soft_cap=self.output_logit_soft_cap,
|
||||||
|
weight_mode=self.weight_mode,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["xLSTMConfig"]
|
||||||
1623
src/transformers/models/xlstm/modeling_xlstm.py
Normal file
1623
src/transformers/models/xlstm/modeling_xlstm.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -270,6 +270,7 @@ from .import_utils import (
|
|||||||
is_uroman_available,
|
is_uroman_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
is_vptq_available,
|
is_vptq_available,
|
||||||
|
is_xlstm_available,
|
||||||
is_yt_dlp_available,
|
is_yt_dlp_available,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
torch_only_method,
|
torch_only_method,
|
||||||
|
|||||||
@@ -588,6 +588,12 @@ def is_causal_conv1d_available():
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_xlstm_available():
|
||||||
|
if is_torch_available():
|
||||||
|
return _is_package_available("xlstm")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_mambapy_available():
|
def is_mambapy_available():
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
return _is_package_available("mambapy")
|
return _is_package_available("mambapy")
|
||||||
|
|||||||
0
tests/models/xlstm/__init__.py
Normal file
0
tests/models/xlstm/__init__.py
Normal file
371
tests/models/xlstm/test_modeling_xlstm.py
Normal file
371
tests/models/xlstm/test_modeling_xlstm.py
Normal file
@@ -0,0 +1,371 @@
|
|||||||
|
# Copyright 2025 NXAI GmbH. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer, is_torch_available, xLSTMConfig
|
||||||
|
from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device
|
||||||
|
|
||||||
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
|
from ...test_configuration_common import ConfigTester
|
||||||
|
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
xLSTMForCausalLM,
|
||||||
|
xLSTMModel,
|
||||||
|
)
|
||||||
|
from transformers.models.xlstm.modeling_xlstm import xLSTMBlock, xLSTMCache
|
||||||
|
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
|
||||||
|
else:
|
||||||
|
is_torch_greater_or_equal_than_2_2 = False
|
||||||
|
|
||||||
|
|
||||||
|
class xLSTMModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=13,
|
||||||
|
num_heads=2,
|
||||||
|
seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
use_labels=True,
|
||||||
|
vocab_size=99,
|
||||||
|
hidden_size=128,
|
||||||
|
qk_dim_factor=0.5,
|
||||||
|
v_dim_factor=1.0,
|
||||||
|
num_hidden_layers=2,
|
||||||
|
max_position_embeddings=512,
|
||||||
|
type_vocab_size=16,
|
||||||
|
type_sequence_label_size=2,
|
||||||
|
num_labels=3,
|
||||||
|
num_choices=4,
|
||||||
|
scope=None,
|
||||||
|
chunkwise_kernel="chunkwise--native_autograd",
|
||||||
|
sequence_kernel="native_sequence__native",
|
||||||
|
step_kernel="native",
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.seq_length = seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_labels = use_labels
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.qk_dim_factor = qk_dim_factor
|
||||||
|
self.v_dim_factor = v_dim_factor
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.type_vocab_size = type_vocab_size
|
||||||
|
self.type_sequence_label_size = type_sequence_label_size
|
||||||
|
self.num_labels = num_labels
|
||||||
|
self.num_choices = num_choices
|
||||||
|
self.scope = scope
|
||||||
|
self.bos_token_id = vocab_size - 1
|
||||||
|
self.eos_token_id = vocab_size - 1
|
||||||
|
self.pad_token_id = vocab_size - 1
|
||||||
|
self.chunkwise_kernel = chunkwise_kernel
|
||||||
|
self.sequence_kernel = sequence_kernel
|
||||||
|
self.step_kernel = step_kernel
|
||||||
|
self.tie_word_embeddings = tie_word_embeddings
|
||||||
|
|
||||||
|
def get_large_model_config(self):
|
||||||
|
cfg = xLSTMConfig.from_pretrained("NX-AI/xLSTM-7b")
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False):
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
sequence_labels = None
|
||||||
|
token_labels = None
|
||||||
|
choice_labels = None
|
||||||
|
if self.use_labels:
|
||||||
|
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||||
|
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||||
|
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||||
|
|
||||||
|
config = self.get_config()
|
||||||
|
|
||||||
|
return (
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
None,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
cfg = xLSTMConfig(
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_hidden_layers=self.num_hidden_layers,
|
||||||
|
qk_dim_factor=self.qk_dim_factor,
|
||||||
|
v_dim_factor=self.v_dim_factor,
|
||||||
|
n_positions=self.max_position_embeddings,
|
||||||
|
type_vocab_size=self.type_vocab_size,
|
||||||
|
use_cache=True,
|
||||||
|
bos_token_id=self.bos_token_id,
|
||||||
|
eos_token_id=self.eos_token_id,
|
||||||
|
pad_token_id=self.pad_token_id,
|
||||||
|
chunkwise_kernel=self.chunkwise_kernel,
|
||||||
|
sequence_kernel=self.sequence_kernel,
|
||||||
|
step_kernel=self.step_kernel,
|
||||||
|
tie_word_embeddings=self.tie_word_embeddings,
|
||||||
|
)
|
||||||
|
# this is needed for compatibility with generic tests
|
||||||
|
# cfg.hidden_size = cfg.embedding_dim
|
||||||
|
# cfg.num_hidden_layers = cfg.num_blocks
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
_,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = self.prepare_config_and_inputs()
|
||||||
|
inputs_dict = {"input_ids": input_ids}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class xLSTMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (xLSTMModel, xLSTMForCausalLM) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = (xLSTMForCausalLM,) if is_torch_available() else ()
|
||||||
|
has_attentions = False # xLSTM does not support attentions
|
||||||
|
fx_compatible = False
|
||||||
|
test_torchscript = False
|
||||||
|
test_model_parallel = False
|
||||||
|
test_pruning = False
|
||||||
|
test_head_masking = False # xLSTM does not have attention heads
|
||||||
|
|
||||||
|
pipeline_model_mapping = (
|
||||||
|
{"feature-extraction": xLSTMModel, "text-generation": xLSTMForCausalLM} if is_torch_available() else {}
|
||||||
|
)
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = xLSTMModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(
|
||||||
|
self, config_class=xLSTMConfig, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_initialization(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=config)
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if "D" in name:
|
||||||
|
if param.requires_grad:
|
||||||
|
# check if it's a ones like
|
||||||
|
self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
|
||||||
|
|
||||||
|
@unittest.skip(reason="xLSTM has no tied weights")
|
||||||
|
def test_tied_weights_keys(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="xLSTM cache slicing test case is an edge case")
|
||||||
|
def test_generate_without_input_ids(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="xLSTM cache slicing test case is an edge case")
|
||||||
|
@parameterized.expand([("greedy", 1), ("beam search", 2)])
|
||||||
|
def test_generate_from_inputs_embeds(self, _, num_beams):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="xLSTM cache slicing test case is an edge case")
|
||||||
|
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="xLSTM cache slicing is interacting with beam search")
|
||||||
|
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_model_outputs_equivalence(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||||
|
with torch.no_grad():
|
||||||
|
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
||||||
|
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||||
|
|
||||||
|
def recursive_check(tuple_object, dict_object):
|
||||||
|
if isinstance(tuple_object, xLSTMCache):
|
||||||
|
recursive_check(tuple_object.rnn_state, dict_object.rnn_state)
|
||||||
|
elif isinstance(tuple_object, (list, tuple)):
|
||||||
|
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
|
||||||
|
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||||
|
elif isinstance(tuple_object, dict):
|
||||||
|
for tuple_iterable_value, dict_iterable_value in zip(
|
||||||
|
tuple_object.values(), dict_object.values()
|
||||||
|
):
|
||||||
|
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||||
|
elif tuple_object is None:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(tuple_object, dict_object, atol=1e-5),
|
||||||
|
msg=(
|
||||||
|
"Tuple and dict output are not equal. Difference:"
|
||||||
|
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
|
||||||
|
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
|
||||||
|
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
recursive_check(tuple_output, dict_output)
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@slow
|
||||||
|
@require_read_token
|
||||||
|
class xLSTMIntegrationTest(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.model_id = "NX-AI/xLSTM-7b"
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_slow=True, legacy=False)
|
||||||
|
self.prompt = ("[INST]Write a hello world program in C++.",)
|
||||||
|
|
||||||
|
def test_simple_generate(self, device):
|
||||||
|
"""
|
||||||
|
Simple generate test to avoid regressions.
|
||||||
|
Note: state-spaces (cuda) implementation and pure torch implementation
|
||||||
|
have irreconciliable differences as of now, which will cause this test to fail
|
||||||
|
in an environment with state-spaces installed.
|
||||||
|
"""
|
||||||
|
tokenizer = self.tokenizer
|
||||||
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
|
|
||||||
|
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16)
|
||||||
|
model.to(device)
|
||||||
|
input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to(
|
||||||
|
device
|
||||||
|
)
|
||||||
|
|
||||||
|
out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30)
|
||||||
|
output_sentence = tokenizer.decode(out[0])
|
||||||
|
ground_truth_sentence = """<s>[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program in C++:\n\n```cpp\n#include <iostream>\n\n"""
|
||||||
|
self.assertEqual(output_sentence, ground_truth_sentence)
|
||||||
|
|
||||||
|
def test_batched_equivalence_with_cache(self):
|
||||||
|
"""
|
||||||
|
Verifies that batched generation matches individual generation.
|
||||||
|
Important because of the specific caching mechanism + statefulness of the xLSTM model.
|
||||||
|
Depending on precision and devices, differences can be observed from generation to generation.
|
||||||
|
"""
|
||||||
|
tokenizer = self.tokenizer
|
||||||
|
prompt = [
|
||||||
|
"[INST]Write C#.[/INST]",
|
||||||
|
"[INST]Write a hello world in C++.[/INST]",
|
||||||
|
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
|
||||||
|
]
|
||||||
|
|
||||||
|
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)
|
||||||
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
|
# batched generation
|
||||||
|
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
|
||||||
|
batched_gen = model.generate(**tokenized_prompts, max_new_tokens=30, use_cache=True)
|
||||||
|
batched_output = tokenizer.batch_decode(batched_gen, skip_special_tokens=True)
|
||||||
|
|
||||||
|
# individual generation
|
||||||
|
|
||||||
|
for index_gen, individual_prompt in enumerate(prompt):
|
||||||
|
inputs = tokenizer(individual_prompt, return_tensors="pt", padding="longest").to(torch_device)
|
||||||
|
individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True)
|
||||||
|
individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
|
||||||
|
self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
|
||||||
|
|
||||||
|
def test_batched_equivalence_without_cache(self):
|
||||||
|
"""
|
||||||
|
Verifies that batched generation matches individual generation without cache.
|
||||||
|
Important because of the specific caching mechanism + statefulness of the xLSTM model.
|
||||||
|
Depending on precision and devices, differences can be observed from generation to generation.
|
||||||
|
"""
|
||||||
|
tokenizer = self.tokenizer
|
||||||
|
prompt = [
|
||||||
|
"[INST]Write C#.[/INST]",
|
||||||
|
"[INST]Write a hello world in C++.[/INST]",
|
||||||
|
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
|
||||||
|
]
|
||||||
|
|
||||||
|
model = xLSTMForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)
|
||||||
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
|
# batched generation
|
||||||
|
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
|
||||||
|
batched_gen = model.generate(**tokenized_prompts, max_new_tokens=30, use_cache=True)
|
||||||
|
batched_output = tokenizer.batch_decode(batched_gen, skip_special_tokens=True)
|
||||||
|
|
||||||
|
# individual generation
|
||||||
|
|
||||||
|
for index_gen, individual_prompt in enumerate(prompt):
|
||||||
|
inputs = tokenizer(individual_prompt, return_tensors="pt", padding="longest").to(torch_device)
|
||||||
|
individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True)
|
||||||
|
individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
|
||||||
|
self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
|
||||||
|
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_xlstm_block_train_vs_eval_equivalence(self):
|
||||||
|
# Based on https://github.com/sustcsonglin/flash-linear-attention/issues/63
|
||||||
|
# Credit to zhixuan-lin
|
||||||
|
|
||||||
|
B, T, D = 4, 512, 768
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
config = xLSTMConfig(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1)
|
||||||
|
|
||||||
|
torch.manual_seed(42)
|
||||||
|
with torch.amp.autocast(device_type="cuda", dtype=dtype):
|
||||||
|
with torch.no_grad():
|
||||||
|
block = xLSTMBlock(config.to_xlstm_block_config(), layer_idx=0).to("cuda")
|
||||||
|
hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda")
|
||||||
|
|
||||||
|
block.train()
|
||||||
|
out_train = block(hidden_states)
|
||||||
|
|
||||||
|
block.eval()
|
||||||
|
out_eval = block(hidden_states)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(out_train, out_eval, atol=1e-3))
|
||||||
@@ -32,6 +32,7 @@ transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
|
|||||||
CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
|
CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
|
||||||
|
|
||||||
SPECIAL_CASES_TO_ALLOW = {
|
SPECIAL_CASES_TO_ALLOW = {
|
||||||
|
"xLSTMConfig": ["add_out_norm", "chunkwise_kernel", "sequence_kernel", "step_kernel"],
|
||||||
"Ernie4_5Config": ["tie_word_embeddings"],
|
"Ernie4_5Config": ["tie_word_embeddings"],
|
||||||
"Ernie4_5_MoeConfig": ["tie_word_embeddings"],
|
"Ernie4_5_MoeConfig": ["tie_word_embeddings"],
|
||||||
"Lfm2Config": ["full_attn_idxs", "tie_word_embeddings"],
|
"Lfm2Config": ["full_attn_idxs", "tie_word_embeddings"],
|
||||||
|
|||||||
Reference in New Issue
Block a user