Add codestral mamba2 (#32080)
* add new model like * draft cuda forward - mismatched keys (sharding on conv1) * match keys successfully * fix split * get generation/forward running (wrong gens, norm?) * :update * some refactoring * fixes * works up until copy to cache * fix * update * NON WORKING VERSION * version that work? * nit * fix config * fix conversion script * working cuda forward * nit * update * simplifcation * make mamba slow simple work * no einops * todo * fix style * no einops * update fix no einsum * nit * remove einops * bug: scan_output differs strongly * add rms norm option * fix fast + slow generation with and w/o cache ✔️ * draft integration tests * remove a big chunk of the einsum * fix slow, fast generations, without any einsum * fix copies * fix structure * fix up modeling and tests * fix tests * clamping is indeed worse * recover mamba2 cache test * fix copies * no cache position (yet) * fix tf tests * fix matmul for generate * fixup * skip cache tests for now * [run-slow]mamba2 * tune out hidden states for padding * test batched generation * propagate attention mask changes * fix past length * fix integration test * style * address comments * update readme * add mamba2 version check * fix tests * [run-slow]mamba2 * skip edge tests * [run-slow]mamba2 * last fixup * [run-slow]mamba2 * update README --------- Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
@@ -438,6 +438,8 @@
|
||||
title: MADLAD-400
|
||||
- local: model_doc/mamba
|
||||
title: Mamba
|
||||
- local: model_doc/mamba2
|
||||
title: mamba2
|
||||
- local: model_doc/marian
|
||||
title: MarianMT
|
||||
- local: model_doc/markuplm
|
||||
|
||||
@@ -194,6 +194,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| [M2M100](model_doc/m2m_100) | ✅ | ❌ | ❌ |
|
||||
| [MADLAD-400](model_doc/madlad-400) | ✅ | ✅ | ✅ |
|
||||
| [Mamba](model_doc/mamba) | ✅ | ❌ | ❌ |
|
||||
| [mamba2](model_doc/mamba2) | ✅ | ❌ | ❌ |
|
||||
| [Marian](model_doc/marian) | ✅ | ✅ | ✅ |
|
||||
| [MarkupLM](model_doc/markuplm) | ✅ | ❌ | ❌ |
|
||||
| [Mask2Former](model_doc/mask2former) | ✅ | ❌ | ❌ |
|
||||
|
||||
106
docs/source/en/model_doc/mamba2.md
Normal file
106
docs/source/en/model_doc/mamba2.md
Normal file
@@ -0,0 +1,106 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Mamba 2
|
||||
|
||||
## Overview
|
||||
|
||||
The Mamba2 model was proposed in [Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality](https://arxiv.org/abs/2405.21060) by Tri Dao and Albert Gu. It is a State Space Model similar to Mamba 1, with better performances in a simplified architecture.
|
||||
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*While Transformers have been the main architecture behind deep learning's success in language modeling, state-space models (SSMs) such as Mamba have recently been shown to match or outperform Transformers at small to medium scale. We show that these families of models are actually quite closely related, and develop a rich framework of theoretical connections between SSMs and variants of attention, connected through various decompositions of a well-studied class of structured semiseparable matrices. Our state space duality (SSD) framework allows us to design a new architecture (Mamba-2) whose core layer is an a refinement of Mamba's selective SSM that is 2-8X faster, while continuing to be competitive with Transformers on language modeling.*
|
||||
|
||||
Tips:
|
||||
|
||||
This version should support all implementations of Mamba 2, and in particular [Mamba-2 codestral](https://huggingface.co/mistralai/Mamba-Codestral-7B-v0.1) from Mistral AI. In particular, mamba 2 codestral was released with a number of `groups` equal to 8, which can be thought intuitively as similar to the number of kv heads in an attention-based model.
|
||||
This model has two different forward passes, `torch_forward` or `cuda_kernels_forward`. The latter uses the original cuda kernels if they are found in your environment, and is slower on the prefill i.e. requires a "warmup run" due to high cpu overhead, see [here](https://github.com/state-spaces/mamba/issues/389#issuecomment-2171755306) and [also here](https://github.com/state-spaces/mamba/issues/355#issuecomment-2147597457). Without compilation, the `torch_forward` implementation is faster by a factor 3 to 4. Further, there are no positional embeddings in this model, but there is an `attention_mask` and a specific logic to mask out hidden states in two places in the case of batched generation, see [here](https://github.com/state-spaces/mamba/issues/66#issuecomment-1863563829) as well. Due to this, in addition to the reimplementation of mamba2 kernels, batched generation and cached generation are expected to have slight discrepancies. Further, the results given by the cuda kernels or the torch forward are expected to be slightly different. The SSM algorithm heavily relies on tensor contractions, which have matmul equivalents but the order of operations is slightly different, making the difference greater at smaller precisions.
|
||||
Another note, shutdown of hidden states corresponding to padding tokens is done in 2 places and mostly has been tested with left-padding. Right-padding will propagate noise down the line and is not guaranteed to yield satisfactory results. `tokenizer.padding_side = "left"` ensures you are using the correct padding side.
|
||||
|
||||
This model was contributed by [Molbap](https://huggingface.co/Molbap), with tremendous help from [Anton Vlasjuk](https://github.com/vasqu).
|
||||
The original code can be found [here](https://github.com/state-spaces/mamba).
|
||||
|
||||
|
||||
# Usage
|
||||
|
||||
### A simple generation example:
|
||||
```python
|
||||
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
|
||||
model = MambaForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
|
||||
input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]
|
||||
|
||||
out = model.generate(input_ids, max_new_tokens=10)
|
||||
print(tokenizer.batch_decode(out))
|
||||
```
|
||||
|
||||
Here's a draft script for finetuning:
|
||||
```python
|
||||
from trl import SFTTrainer
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoTokenizer, Mamba2ForCausalLM, TrainingArguments
|
||||
model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.padding_side = "left" #enforce padding side left
|
||||
|
||||
model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
|
||||
dataset = load_dataset("Abirate/english_quotes", split="train")
|
||||
# Without CUDA kernels, batch size of 2 occupies one 80GB device
|
||||
# but precision can be reduced.
|
||||
# Experiments and trials welcome!
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=2,
|
||||
logging_dir='./logs',
|
||||
logging_steps=10,
|
||||
learning_rate=2e-3
|
||||
)
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
target_modules=["embeddings", "in_proj", "out_proj"],
|
||||
task_type="CAUSAL_LM",
|
||||
bias="none"
|
||||
)
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
args=training_args,
|
||||
peft_config=lora_config,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="quote",
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
|
||||
## Mamba2Config
|
||||
|
||||
[[autodoc]] Mamba2Config
|
||||
|
||||
## Mamba2Model
|
||||
|
||||
[[autodoc]] Mamba2Model
|
||||
- forward
|
||||
|
||||
## Mamba2LMHeadModel
|
||||
|
||||
[[autodoc]] Mamba2ForCausalLM
|
||||
- forward
|
||||
@@ -544,6 +544,7 @@ _import_structure = {
|
||||
],
|
||||
"models.m2m_100": ["M2M100Config"],
|
||||
"models.mamba": ["MambaConfig"],
|
||||
"models.mamba2": ["Mamba2Config"],
|
||||
"models.marian": ["MarianConfig"],
|
||||
"models.markuplm": [
|
||||
"MarkupLMConfig",
|
||||
@@ -2550,6 +2551,13 @@ else:
|
||||
"MambaPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.mamba2"].extend(
|
||||
[
|
||||
"Mamba2ForCausalLM",
|
||||
"Mamba2Model",
|
||||
"Mamba2PreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"])
|
||||
_import_structure["models.markuplm"].extend(
|
||||
[
|
||||
@@ -5240,6 +5248,7 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .models.m2m_100 import M2M100Config
|
||||
from .models.mamba import MambaConfig
|
||||
from .models.mamba2 import Mamba2Config
|
||||
from .models.marian import MarianConfig
|
||||
from .models.markuplm import (
|
||||
MarkupLMConfig,
|
||||
@@ -7046,6 +7055,11 @@ if TYPE_CHECKING:
|
||||
MambaModel,
|
||||
MambaPreTrainedModel,
|
||||
)
|
||||
from .models.mamba2 import (
|
||||
Mamba2ForCausalLM,
|
||||
Mamba2Model,
|
||||
Mamba2PreTrainedModel,
|
||||
)
|
||||
from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel
|
||||
from .models.markuplm import (
|
||||
MarkupLMForQuestionAnswering,
|
||||
|
||||
@@ -135,6 +135,7 @@ from . import (
|
||||
lxmert,
|
||||
m2m_100,
|
||||
mamba,
|
||||
mamba2,
|
||||
marian,
|
||||
markuplm,
|
||||
mask2former,
|
||||
|
||||
@@ -152,6 +152,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("lxmert", "LxmertConfig"),
|
||||
("m2m_100", "M2M100Config"),
|
||||
("mamba", "MambaConfig"),
|
||||
("mamba2", "Mamba2Config"),
|
||||
("marian", "MarianConfig"),
|
||||
("markuplm", "MarkupLMConfig"),
|
||||
("mask2former", "Mask2FormerConfig"),
|
||||
@@ -440,6 +441,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("m2m_100", "M2M100"),
|
||||
("madlad-400", "MADLAD-400"),
|
||||
("mamba", "Mamba"),
|
||||
("mamba2", "mamba2"),
|
||||
("marian", "Marian"),
|
||||
("markuplm", "MarkupLM"),
|
||||
("mask2former", "Mask2Former"),
|
||||
|
||||
@@ -144,6 +144,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("lxmert", "LxmertModel"),
|
||||
("m2m_100", "M2M100Model"),
|
||||
("mamba", "MambaModel"),
|
||||
("mamba2", "Mamba2Model"),
|
||||
("marian", "MarianModel"),
|
||||
("markuplm", "MarkupLMModel"),
|
||||
("mask2former", "Mask2FormerModel"),
|
||||
@@ -310,6 +311,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
|
||||
("luke", "LukeForMaskedLM"),
|
||||
("lxmert", "LxmertForPreTraining"),
|
||||
("mamba", "MambaForCausalLM"),
|
||||
("mamba2", "Mamba2ForCausalLM"),
|
||||
("mega", "MegaForMaskedLM"),
|
||||
("megatron-bert", "MegatronBertForPreTraining"),
|
||||
("mobilebert", "MobileBertForPreTraining"),
|
||||
@@ -394,6 +396,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
|
||||
("luke", "LukeForMaskedLM"),
|
||||
("m2m_100", "M2M100ForConditionalGeneration"),
|
||||
("mamba", "MambaForCausalLM"),
|
||||
("mamba2", "Mamba2ForCausalLM"),
|
||||
("marian", "MarianMTModel"),
|
||||
("mega", "MegaForMaskedLM"),
|
||||
("megatron-bert", "MegatronBertForCausalLM"),
|
||||
@@ -472,6 +475,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("jetmoe", "JetMoeForCausalLM"),
|
||||
("llama", "LlamaForCausalLM"),
|
||||
("mamba", "MambaForCausalLM"),
|
||||
("mamba2", "Mamba2ForCausalLM"),
|
||||
("marian", "MarianForCausalLM"),
|
||||
("mbart", "MBartForCausalLM"),
|
||||
("mega", "MegaForCausalLM"),
|
||||
|
||||
@@ -270,6 +270,7 @@ else:
|
||||
("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("mamba2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
(
|
||||
"mbart",
|
||||
|
||||
58
src/transformers/models/mamba2/__init__.py
Normal file
58
src/transformers/models/mamba2/__init__.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
is_torch_available,
|
||||
)
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_mamba2": ["Mamba2Config", "Mamba2OnnxConfig"],
|
||||
}
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_mamba2"] = [
|
||||
"Mamba2ForCausalLM",
|
||||
"Mamba2Model",
|
||||
"Mamba2PreTrainedModel",
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_mamba2 import Mamba2Config, Mamba2OnnxConfig
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_mamba2 import (
|
||||
Mamba2ForCausalLM,
|
||||
Mamba2Model,
|
||||
Mamba2PreTrainedModel,
|
||||
)
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
184
src/transformers/models/mamba2/configuration_mamba2.py
Normal file
184
src/transformers/models/mamba2/configuration_mamba2.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""MAMBA2 configuration"""
|
||||
|
||||
import math
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Mamba2Config(PretrainedConfig):
|
||||
"""
|
||||
This is the configuration class to store the configuration of a [`Mamba2Model`]. It is used to instantiate a MAMBA2
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the MAMBA2
|
||||
[state-spaces/mamba2-2.8b](https://huggingface.co/state-spaces/mamba2-2.8b) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
num_heads (`int`, *optional*, defaults to 128):
|
||||
Number of heads for the evolution matrices of mamba 2.
|
||||
head_dim (`int`, *optional*, defaults to 64):
|
||||
Dimension of each head.
|
||||
vocab_size (`int`, *optional*, defaults to 32768):
|
||||
Vocabulary size of the MAMBA2 model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Mamba2Model`].
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimensionality of the embeddings and hidden states.
|
||||
state_size (`int`, *optional*, defaults to 128): shape of the state space latents.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 64):
|
||||
Number of hidden layers in the model.
|
||||
layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon to use in the layer normalization layers.
|
||||
pad_token_id (`int`, *optional*, defaults to 1):
|
||||
Padding token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 0):
|
||||
The id of the beginning of sentence token in the vocabulary.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
The id of the end of sentence token in the vocabulary.
|
||||
expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
|
||||
conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
|
||||
n_groups (`int`, *optional*, defaults to 8):
|
||||
Number of groups for the evolution matrices of mamba 2.
|
||||
use_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
|
||||
use_conv_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to use bias in the convolution layer of the mixer block.
|
||||
hidden_act (`str`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
initializer_range (`float`, *optional*, defaults to 0.1):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
residual_in_fp32 (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model
|
||||
time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
|
||||
Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
|
||||
time_step_min (`float`, *optional*, defaults to 0.001):
|
||||
Minimum `time_step` used to bound `dt_proj.bias`.
|
||||
time_step_max (`float`, *optional*, defaults to 0.1):
|
||||
Maximum `time_step` used to bound `dt_proj.bias`.
|
||||
time_step_floor (`float`, *optional*, defaults to 0.0001):
|
||||
Minimum clamping value of the `dt_proj.bias` layer initialization.
|
||||
time_step_limit (`tuple`, *optional*, defaults to `(0.0, inf)`):
|
||||
Accepted range of time step values.
|
||||
rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to rescale `out_proj` weights when initializing.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the cache should be used.
|
||||
norm_before_gate (`bool`, *optional*, defaults to `True`):
|
||||
Option of cuda kernels -whether to normalize before the gate or not.
|
||||
rms_norm (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use RMS norm or not.
|
||||
chunk_size (`int`, *optional*, defaults to 256):
|
||||
Size of the chunks that will comprise the sequence.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie word embeddings or not.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Mamba2Config, Mamba2Model
|
||||
|
||||
>>> # Initializing a Mamba2 configuration
|
||||
>>> configuration = Mamba2Config()
|
||||
|
||||
>>> # Initializing a model (with random weights) from the configuration
|
||||
>>> model = Mamba2Model(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "mamba2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads=128,
|
||||
head_dim=64,
|
||||
vocab_size=32768,
|
||||
hidden_size=4096,
|
||||
state_size=128,
|
||||
num_hidden_layers=64,
|
||||
layer_norm_epsilon=1e-5,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
expand=2,
|
||||
conv_kernel=4,
|
||||
n_groups=8,
|
||||
use_bias=False,
|
||||
use_conv_bias=True,
|
||||
hidden_act="silu",
|
||||
initializer_range=0.1,
|
||||
residual_in_fp32=True,
|
||||
time_step_rank="auto",
|
||||
time_step_min=0.001,
|
||||
time_step_max=0.1,
|
||||
time_step_floor=1e-4,
|
||||
time_step_limit=(0.0, float("inf")),
|
||||
rescale_prenorm_residual=False,
|
||||
use_cache=True,
|
||||
norm_before_gate=True,
|
||||
rms_norm=True,
|
||||
chunk_size=256,
|
||||
tie_word_embeddings=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.state_size = state_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.conv_kernel = conv_kernel
|
||||
self.expand = expand
|
||||
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.use_bias = use_bias
|
||||
self.use_conv_bias = use_conv_bias
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
|
||||
self.time_step_min = time_step_min
|
||||
self.time_step_max = time_step_max
|
||||
self.time_step_floor = time_step_floor
|
||||
self.rescale_prenorm_residual = rescale_prenorm_residual
|
||||
self.residual_in_fp32 = residual_in_fp32
|
||||
self.use_cache = use_cache
|
||||
self.n_groups = n_groups
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.norm_before_gate = norm_before_gate
|
||||
self.rms_norm = rms_norm
|
||||
self.state_size = state_size
|
||||
self.chunk_size = chunk_size
|
||||
self.time_step_limit = time_step_limit
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
|
||||
super().__init__(
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -0,0 +1,69 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""This script can be used to convert checkpoints provided in the `mamba2_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
|
||||
from transformers import LlamaTokenizerFast, Mamba2Config, Mamba2ForCausalLM
|
||||
|
||||
|
||||
def convert_mamba2_checkpoint_file_to_huggingface_model_file(
|
||||
mamba2_checkpoint_path: str, tokenizer_model_path: str, output_dir: str
|
||||
) -> None:
|
||||
hf_config = Mamba2Config()
|
||||
hf_model = Mamba2ForCausalLM(hf_config)
|
||||
# Load weights and config from paths
|
||||
original_state_dict = {}
|
||||
with safe_open(mamba2_checkpoint_path, framework="pt") as f:
|
||||
for k in f.keys():
|
||||
newk = k.removeprefix("model.")
|
||||
original_state_dict[newk] = f.get_tensor(k).clone()
|
||||
|
||||
hf_model.load_state_dict(original_state_dict)
|
||||
|
||||
# Save new model to pytorch_dump_path
|
||||
hf_model.to(torch.bfloat16).save_pretrained(output_dir)
|
||||
tokenizer_class = LlamaTokenizerFast
|
||||
tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--mamba2_checkpoint_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a `pytorch_model.bin` mamba2_ssm checkpoint file to be converted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--tokenizer_model_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a `config.json` file corresponding to a Mamba2Config of the original mamba2_ssm model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_mamba2_checkpoint_file_to_huggingface_model_file(
|
||||
args.mamba2_checkpoint_file, args.tokenizer_model_path, args.output_dir
|
||||
)
|
||||
1082
src/transformers/models/mamba2/modeling_mamba2.py
Normal file
1082
src/transformers/models/mamba2/modeling_mamba2.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -5542,6 +5542,27 @@ class MambaPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Mamba2ForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Mamba2Model(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Mamba2PreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MarianForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -385,6 +385,21 @@ def is_mamba_ssm_available():
|
||||
return False
|
||||
|
||||
|
||||
def is_mamba_2_ssm_available():
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
return False
|
||||
else:
|
||||
if _is_package_available("mamba_ssm"):
|
||||
import mamba_ssm
|
||||
|
||||
if version.parse(mamba_ssm.__version__) >= version.parse("2.0.4"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_causal_conv1d_available():
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
0
tests/models/mamba2/__init__.py
Normal file
0
tests/models/mamba2/__init__.py
Normal file
387
tests/models/mamba2/test_modeling_mamba2.py
Normal file
387
tests/models/mamba2/test_modeling_mamba2.py
Normal file
@@ -0,0 +1,387 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, Mamba2Config, is_torch_available
|
||||
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
Mamba2ForCausalLM,
|
||||
Mamba2Model,
|
||||
)
|
||||
from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
|
||||
class Mamba2ModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=14,
|
||||
num_heads=8,
|
||||
n_groups=8,
|
||||
state_size=2,
|
||||
head_dim=8,
|
||||
conv_kernel=4,
|
||||
chunk_size=8,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=2,
|
||||
hidden_act="silu",
|
||||
hidden_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
tie_word_embeddings=False,
|
||||
):
|
||||
self.parent = parent
|
||||
self.num_heads = num_heads
|
||||
self.n_groups = n_groups
|
||||
self.head_dim = head_dim
|
||||
self.state_size = state_size
|
||||
self.conv_kernel = conv_kernel
|
||||
self.chunk_size = chunk_size
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
self.bos_token_id = vocab_size - 1
|
||||
self.eos_token_id = vocab_size - 1
|
||||
self.pad_token_id = vocab_size - 1
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
|
||||
def get_large_model_config(self):
|
||||
return Mamba2Config.from_pretrained("revision='refs/pr/9'")
|
||||
|
||||
def prepare_config_and_inputs(
|
||||
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
|
||||
):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = self.get_config(
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
None,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
)
|
||||
|
||||
def get_config(self, gradient_checkpointing=False):
|
||||
return Mamba2Config(
|
||||
head_dim=self.head_dim,
|
||||
num_heads=self.num_heads,
|
||||
n_groups=self.n_groups,
|
||||
state_size=self.state_size,
|
||||
conv_kernel=self.conv_kernel,
|
||||
chunk_size=self.chunk_size,
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
activation_function=self.hidden_act,
|
||||
n_positions=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
use_cache=True,
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
tie_word_embeddings=self.tie_word_embeddings,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
_,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = self.prepare_config_and_inputs()
|
||||
inputs_dict = {"input_ids": input_ids}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
|
||||
)
|
||||
@require_torch
|
||||
class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Mamba2Model, Mamba2ForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (Mamba2ForCausalLM,) if is_torch_available() else ()
|
||||
has_attentions = False # Mamba does not support attentions
|
||||
fx_compatible = False # FIXME let's try to support this @molbap
|
||||
test_torchscript = False # FIXME I think this should be doable @molbap @ArthurZucker
|
||||
test_missing_keys = False
|
||||
test_model_parallel = False
|
||||
test_pruning = False
|
||||
test_head_masking = False # Mamba does not have attention heads
|
||||
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": Mamba2Model, "text-generation": Mamba2ForCausalLM} if is_torch_available() else {}
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Mamba2ModelTester(self)
|
||||
self.config_tester = ConfigTester(
|
||||
self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
|
||||
)
|
||||
|
||||
def test_initialization(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=config)
|
||||
for name, param in model.named_parameters():
|
||||
if "D" in name:
|
||||
if param.requires_grad:
|
||||
# check if it's a ones like
|
||||
self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
|
||||
|
||||
@unittest.skip(reason="Mamba 2 weights are not tied")
|
||||
def test_tied_weights_keys(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search")
|
||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search")
|
||||
def test_beam_sample_generate(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
|
||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Initialization of mamba2 fails this")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
|
||||
def test_generate_from_inputs_embeds_decoder_only(self):
|
||||
pass
|
||||
|
||||
def test_model_outputs_equivalence(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||
with torch.no_grad():
|
||||
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, Mamba2Cache): # MODIFIED PART START
|
||||
recursive_check(tuple_object.conv_states, dict_object.conv_states)
|
||||
recursive_check(tuple_object.ssm_states, dict_object.ssm_states)
|
||||
elif isinstance(tuple_object, (List, Tuple)): # MODIFIED PART END
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif isinstance(tuple_object, Dict):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(
|
||||
tuple_object.values(), dict_object.values()
|
||||
):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif tuple_object is None:
|
||||
return
|
||||
else:
|
||||
self.assertTrue(
|
||||
torch.allclose(tuple_object, dict_object, atol=1e-5),
|
||||
msg=(
|
||||
"Tuple and dict output are not equal. Difference:"
|
||||
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
|
||||
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
|
||||
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
|
||||
),
|
||||
)
|
||||
|
||||
recursive_check(tuple_output, dict_output)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
class Mamba2IntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.model_id = "mistralai/Mamba-Codestral-7B-v0.1"
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.model_id, revision="refs/pr/9", from_slow=True, legacy=False
|
||||
)
|
||||
self.prompt = ("[INST]Write a hello world program in C++.",)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
(torch_device,),
|
||||
]
|
||||
)
|
||||
@slow
|
||||
@require_torch
|
||||
def test_simple_generate(self, device):
|
||||
"""
|
||||
Simple generate test to avoid regressions.
|
||||
Note: state-spaces (cuda) implementation and pure torch implementation
|
||||
have irreconciliable differences as of now, which will cause this test to fail
|
||||
in an environment with state-spaces installed.
|
||||
"""
|
||||
tokenizer = self.tokenizer
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16)
|
||||
model.to(device)
|
||||
input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to(
|
||||
device
|
||||
)
|
||||
|
||||
out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30)
|
||||
output_sentence = tokenizer.decode(out[0])
|
||||
ground_truth_sentence = """<s>[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program in C++:\n\n```cpp\n#include <iostream>\n\n"""
|
||||
self.assertEqual(output_sentence, ground_truth_sentence)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_batched_equivalence_with_cache(self):
|
||||
"""
|
||||
Verifies that batched generation matches individual generation.
|
||||
Important because of the specific caching mechanism + statefulness of mamba model.
|
||||
Depending on precision and devices, differences can be observed from generation to generation.
|
||||
"""
|
||||
tokenizer = self.tokenizer
|
||||
prompt = [
|
||||
"[INST]Write C#.[/INST]",
|
||||
"[INST]Write a hello world in C++.[/INST]",
|
||||
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
|
||||
]
|
||||
|
||||
model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16).to(
|
||||
torch_device
|
||||
)
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
# batched generation
|
||||
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
|
||||
batched_gen = model.generate(**tokenized_prompts, max_new_tokens=30, use_cache=True)
|
||||
batched_output = tokenizer.batch_decode(batched_gen, skip_special_tokens=True)
|
||||
|
||||
# individual generation
|
||||
|
||||
for index_gen, individual_prompt in enumerate(prompt):
|
||||
inputs = tokenizer(individual_prompt, return_tensors="pt", padding="longest").to(torch_device)
|
||||
individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True)
|
||||
individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
|
||||
self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_batched_equivalence_without_cache(self):
|
||||
"""
|
||||
Verifies that batched generation matches individual generation without cache.
|
||||
Important because of the specific caching mechanism + statefulness of mamba model.
|
||||
Depending on precision and devices, differences can be observed from generation to generation.
|
||||
"""
|
||||
tokenizer = self.tokenizer
|
||||
prompt = [
|
||||
"[INST]Write C#.[/INST]",
|
||||
"[INST]Write a hello world in C++.[/INST]",
|
||||
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
|
||||
]
|
||||
|
||||
model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16).to(
|
||||
torch_device
|
||||
)
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
# batched generation
|
||||
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
|
||||
batched_gen = model.generate(**tokenized_prompts, max_new_tokens=30, use_cache=True)
|
||||
batched_output = tokenizer.batch_decode(batched_gen, skip_special_tokens=True)
|
||||
|
||||
# individual generation
|
||||
|
||||
for index_gen, individual_prompt in enumerate(prompt):
|
||||
inputs = tokenizer(individual_prompt, return_tensors="pt", padding="longest").to(torch_device)
|
||||
individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True)
|
||||
individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
|
||||
self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
|
||||
Reference in New Issue
Block a user