Add the Bamba Model (#34982)
* initial commit for PR Co-authored-by: Gabe Goodhart <gabe.l.hart@gmail.com> * rename dynamic cache Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * add more unit tests Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * add integration test Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * add integration test Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * Add modular bamba file * Remove trainer changes from unrelated PR * Modify modular and cofig to get model running * Fix some CI errors and beam search * Fix a plethora of bugs from CI/docs/etc * Add bamba to models with special caches * Updat to newer mamba PR for mamba sublayer * fix test_left_padding_compatibility Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * fix style Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * fix remaining tests Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * missed this test Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * ran make style Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * move slow tag to integration obj Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * make style Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * address comments Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * fix modular Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * left out one part of modular Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * change model Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * Make Rotary modular as well * Update bamba.md Added overview, update Model inference card and added config * Update bamba.md * Update bamba.md * Update bamba.md Minor fixes * Add docs for config and model back Signed-off-by: Antoni Viros i Martin <aviros@ibm.com> * Add warning when using fast kernels * replaced generate example Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * Address comments from PR Signed-off-by: Antoni Viros i Martin <aviros@ibm.com> * Propagate attention fixes Signed-off-by: Antoni Viros i Martin <aviros@ibm.com> * Fix attention interfaces to the new API Signed-off-by: Antoni Viros i Martin <aviros@ibm.com> * Fix API for decoder layer Signed-off-by: Antoni Viros i Martin <aviros@ibm.com> * Remove extra weights Signed-off-by: Antoni Viros i Martin <aviros@ibm.com> --------- Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> Signed-off-by: Antoni Viros i Martin <aviros@ibm.com> Co-authored-by: Gabe Goodhart <gabe.l.hart@gmail.com> Co-authored-by: Antoni Viros i Martin <aviros@ibm.com> Co-authored-by: divya-kumari32 <72085811+divya-kumari32@users.noreply.github.com> Co-authored-by: Antoni Viros <ani300@gmail.com>
This commit is contained in:
committed by
GitHub
parent
9a94dfe123
commit
9613933b02
@@ -322,6 +322,8 @@
|
|||||||
sections:
|
sections:
|
||||||
- local: model_doc/albert
|
- local: model_doc/albert
|
||||||
title: ALBERT
|
title: ALBERT
|
||||||
|
- local: model_doc/bamba
|
||||||
|
title: Bamba
|
||||||
- local: model_doc/bart
|
- local: model_doc/bart
|
||||||
title: BART
|
title: BART
|
||||||
- local: model_doc/barthez
|
- local: model_doc/barthez
|
||||||
|
|||||||
@@ -66,6 +66,7 @@ Flax), PyTorch, and/or TensorFlow.
|
|||||||
| [AriaText](model_doc/aria_text) | ✅ | ❌ | ❌ |
|
| [AriaText](model_doc/aria_text) | ✅ | ❌ | ❌ |
|
||||||
| [Audio Spectrogram Transformer](model_doc/audio-spectrogram-transformer) | ✅ | ❌ | ❌ |
|
| [Audio Spectrogram Transformer](model_doc/audio-spectrogram-transformer) | ✅ | ❌ | ❌ |
|
||||||
| [Autoformer](model_doc/autoformer) | ✅ | ❌ | ❌ |
|
| [Autoformer](model_doc/autoformer) | ✅ | ❌ | ❌ |
|
||||||
|
| [Bamba](model_doc/bamba) | ✅ | ❌ | ❌ |
|
||||||
| [Bark](model_doc/bark) | ✅ | ❌ | ❌ |
|
| [Bark](model_doc/bark) | ✅ | ❌ | ❌ |
|
||||||
| [BART](model_doc/bart) | ✅ | ✅ | ✅ |
|
| [BART](model_doc/bart) | ✅ | ✅ | ✅ |
|
||||||
| [BARThez](model_doc/barthez) | ✅ | ✅ | ✅ |
|
| [BARThez](model_doc/barthez) | ✅ | ✅ | ✅ |
|
||||||
|
|||||||
64
docs/source/en/model_doc/bamba.md
Normal file
64
docs/source/en/model_doc/bamba.md
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
|
||||||
|
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||||
|
rendered properly in your Markdown viewer.
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
# Bamba
|
||||||
|
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Bamba-9B is a decoder-only language model based on the [Mamba-2](https://github.com/state-spaces/mamba) architecture and is designed to handle a wide range of text generation tasks. It is trained from scratch using a two-stage training approach. In the first stage, the model is trained on 2 trillion tokens from the Dolma v1.7 dataset. In the second stage, it undergoes additional training on 200 billion tokens, leveraging a carefully curated blend of high-quality data to further refine its performance and enhance output quality.
|
||||||
|
|
||||||
|
Checkout all Bamba-9B model checkpoints [here](https://github.com/foundation-model-stack/bamba).
|
||||||
|
|
||||||
|
## BambaConfig
|
||||||
|
|
||||||
|
| Model | Params | # Layers | Hidden Dim. | Attention Heads | GQA | KV Heads | Context Length | Tied Embeddings |
|
||||||
|
|-------------------|--------------|----------|-------------|-----------------|-----|----------|----------------|------------------|
|
||||||
|
| Bamba | 9B (9.78B) | 32 | 4096 | 32 | Yes | 8 | 4096 | True |
|
||||||
|
|
||||||
|
[[autodoc]] BambaConfig
|
||||||
|
|
||||||
|
<!---
|
||||||
|
## Usage Tips
|
||||||
|
|
||||||
|
Tips:
|
||||||
|
|
||||||
|
- The architecture is based on Mamba-2 models.
|
||||||
|
|
||||||
|
## BambaModel
|
||||||
|
|
||||||
|
[[autodoc]] BambaModel
|
||||||
|
- forward
|
||||||
|
-->
|
||||||
|
|
||||||
|
## BambaForCausalLM
|
||||||
|
|
||||||
|
```python
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained("ibm-fms/Bamba-9B")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B")
|
||||||
|
|
||||||
|
message = ["Mamba is a snake with following properties "]
|
||||||
|
inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
|
||||||
|
response = model.generate(**inputs, max_new_tokens=64)
|
||||||
|
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
|
||||||
|
```
|
||||||
|
|
||||||
|
[[autodoc]] BambaForCausalLM
|
||||||
|
- forward
|
||||||
|
|
||||||
|
This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).
|
||||||
@@ -39,6 +39,7 @@ FlashAttention-2 is experimental and may change considerably in future versions.
|
|||||||
FlashAttention-2 is currently supported for the following architectures:
|
FlashAttention-2 is currently supported for the following architectures:
|
||||||
* [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration)
|
* [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration)
|
||||||
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
|
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
|
||||||
|
* [Bamba](https://huggingface.co/docs/transformers/model_doc/bamba#transformers.BambaModel)
|
||||||
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
|
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
|
||||||
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
|
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
|
||||||
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
|
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
|
||||||
@@ -220,6 +221,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
|||||||
* [Albert](https://huggingface.co/docs/transformers/model_doc/albert#transformers.AlbertModel)
|
* [Albert](https://huggingface.co/docs/transformers/model_doc/albert#transformers.AlbertModel)
|
||||||
* [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration)
|
* [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration)
|
||||||
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
|
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
|
||||||
|
* [Bamba](https://huggingface.co/docs/transformers/model_doc/bamba#transformers.BambaModel)
|
||||||
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
|
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
|
||||||
* [Beit](https://huggingface.co/docs/transformers/model_doc/beit#transformers.BeitModel)
|
* [Beit](https://huggingface.co/docs/transformers/model_doc/beit#transformers.BeitModel)
|
||||||
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
|
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
|
||||||
|
|||||||
@@ -193,6 +193,7 @@ _import_structure = {
|
|||||||
"AutoTokenizer",
|
"AutoTokenizer",
|
||||||
],
|
],
|
||||||
"models.autoformer": ["AutoformerConfig"],
|
"models.autoformer": ["AutoformerConfig"],
|
||||||
|
"models.bamba": ["BambaConfig"],
|
||||||
"models.bark": [
|
"models.bark": [
|
||||||
"BarkCoarseConfig",
|
"BarkCoarseConfig",
|
||||||
"BarkConfig",
|
"BarkConfig",
|
||||||
@@ -1540,6 +1541,13 @@ else:
|
|||||||
"AutoformerPreTrainedModel",
|
"AutoformerPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
_import_structure["models.bamba"].extend(
|
||||||
|
[
|
||||||
|
"BambaForCausalLM",
|
||||||
|
"BambaModel",
|
||||||
|
"BambaPreTrainedModel",
|
||||||
|
]
|
||||||
|
)
|
||||||
_import_structure["models.bark"].extend(
|
_import_structure["models.bark"].extend(
|
||||||
[
|
[
|
||||||
"BarkCausalModel",
|
"BarkCausalModel",
|
||||||
@@ -5104,6 +5112,7 @@ if TYPE_CHECKING:
|
|||||||
from .models.autoformer import (
|
from .models.autoformer import (
|
||||||
AutoformerConfig,
|
AutoformerConfig,
|
||||||
)
|
)
|
||||||
|
from .models.bamba import BambaConfig
|
||||||
from .models.bark import (
|
from .models.bark import (
|
||||||
BarkCoarseConfig,
|
BarkCoarseConfig,
|
||||||
BarkConfig,
|
BarkConfig,
|
||||||
@@ -6493,6 +6502,7 @@ if TYPE_CHECKING:
|
|||||||
AutoformerModel,
|
AutoformerModel,
|
||||||
AutoformerPreTrainedModel,
|
AutoformerPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
from .models.bamba import BambaForCausalLM, BambaModel, BambaPreTrainedModel
|
||||||
from .models.bark import (
|
from .models.bark import (
|
||||||
BarkCausalModel,
|
BarkCausalModel,
|
||||||
BarkCoarseModel,
|
BarkCoarseModel,
|
||||||
|
|||||||
@@ -1693,6 +1693,7 @@ class GenerationMixin:
|
|||||||
self._supports_cache_class
|
self._supports_cache_class
|
||||||
and "jamba" not in self.__class__.__name__.lower()
|
and "jamba" not in self.__class__.__name__.lower()
|
||||||
and "zamba" not in self.__class__.__name__.lower()
|
and "zamba" not in self.__class__.__name__.lower()
|
||||||
|
and "bamba" not in self.__class__.__name__.lower()
|
||||||
)
|
)
|
||||||
|
|
||||||
def _prepare_cache_for_generation(
|
def _prepare_cache_for_generation(
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from . import (
|
|||||||
audio_spectrogram_transformer,
|
audio_spectrogram_transformer,
|
||||||
auto,
|
auto,
|
||||||
autoformer,
|
autoformer,
|
||||||
|
bamba,
|
||||||
bark,
|
bark,
|
||||||
bart,
|
bart,
|
||||||
barthez,
|
barthez,
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
|||||||
("aria_text", "AriaTextConfig"),
|
("aria_text", "AriaTextConfig"),
|
||||||
("audio-spectrogram-transformer", "ASTConfig"),
|
("audio-spectrogram-transformer", "ASTConfig"),
|
||||||
("autoformer", "AutoformerConfig"),
|
("autoformer", "AutoformerConfig"),
|
||||||
|
("bamba", "BambaConfig"),
|
||||||
("bark", "BarkConfig"),
|
("bark", "BarkConfig"),
|
||||||
("bart", "BartConfig"),
|
("bart", "BartConfig"),
|
||||||
("beit", "BeitConfig"),
|
("beit", "BeitConfig"),
|
||||||
@@ -337,6 +338,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
|||||||
("aria_text", "AriaText"),
|
("aria_text", "AriaText"),
|
||||||
("audio-spectrogram-transformer", "Audio Spectrogram Transformer"),
|
("audio-spectrogram-transformer", "Audio Spectrogram Transformer"),
|
||||||
("autoformer", "Autoformer"),
|
("autoformer", "Autoformer"),
|
||||||
|
("bamba", "Bamba"),
|
||||||
("bark", "Bark"),
|
("bark", "Bark"),
|
||||||
("bart", "BART"),
|
("bart", "BART"),
|
||||||
("barthez", "BARThez"),
|
("barthez", "BARThez"),
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("aria_text", "AriaTextModel"),
|
("aria_text", "AriaTextModel"),
|
||||||
("audio-spectrogram-transformer", "ASTModel"),
|
("audio-spectrogram-transformer", "ASTModel"),
|
||||||
("autoformer", "AutoformerModel"),
|
("autoformer", "AutoformerModel"),
|
||||||
|
("bamba", "BambaModel"),
|
||||||
("bark", "BarkModel"),
|
("bark", "BarkModel"),
|
||||||
("bart", "BartModel"),
|
("bart", "BartModel"),
|
||||||
("beit", "BeitModel"),
|
("beit", "BeitModel"),
|
||||||
@@ -471,6 +472,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
|||||||
[
|
[
|
||||||
# Model for Causal LM mapping
|
# Model for Causal LM mapping
|
||||||
("aria_text", "AriaTextForCausalLM"),
|
("aria_text", "AriaTextForCausalLM"),
|
||||||
|
("bamba", "BambaForCausalLM"),
|
||||||
("bart", "BartForCausalLM"),
|
("bart", "BartForCausalLM"),
|
||||||
("bert", "BertLMHeadModel"),
|
("bert", "BertLMHeadModel"),
|
||||||
("bert-generation", "BertGenerationDecoder"),
|
("bert-generation", "BertGenerationDecoder"),
|
||||||
|
|||||||
28
src/transformers/models/bamba/__init__.py
Normal file
28
src/transformers/models/bamba/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...utils import _LazyModule
|
||||||
|
from ...utils.import_utils import define_import_structure
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .configuration_bamba import *
|
||||||
|
from .modeling_bamba import *
|
||||||
|
from .processing_bamba import *
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
_file = globals()["__file__"]
|
||||||
|
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
||||||
206
src/transformers/models/bamba/configuration_bamba.py
Normal file
206
src/transformers/models/bamba/configuration_bamba.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Bamba model configuration"""
|
||||||
|
|
||||||
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BambaConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`BambaModel`]. It is used to instantiate a
|
||||||
|
BambaModel model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||||
|
with defaults taken from [ibm-fms/Bamba-9.8b-2.2T-hf](https://huggingface.co/ibm-fms/Bamba-9.8b-2.2T-hf).
|
||||||
|
|
||||||
|
The BambaModel is a hybrid [mamba2](https://github.com/state-spaces/mamba) architecture with SwiGLU.
|
||||||
|
The checkpoints are jointly trained by IBM, Princeton, and UIUC.
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size (`int`, *optional*, defaults to 128000):
|
||||||
|
Vocabulary size of the Bamba model. Defines the number of different tokens that can be represented by the
|
||||||
|
`inputs_ids` passed when calling [`BambaModel`]
|
||||||
|
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
|
||||||
|
model has a output word embedding layer.
|
||||||
|
hidden_size (`int`, *optional*, defaults to 4096):
|
||||||
|
Dimension of the hidden representations.
|
||||||
|
intermediate_size (`int`, *optional*, defaults to 14336):
|
||||||
|
Dimension of the MLP representations.
|
||||||
|
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||||
|
Number of hidden layers in the Transformer encoder.
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||||
|
Number of attention heads for each attention layer in the Transformer encoder.
|
||||||
|
num_key_value_heads (`int`, *optional*, defaults to 8):
|
||||||
|
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||||
|
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||||
|
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||||
|
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||||
|
by meanpooling all the original heads within that group. For more details checkout [this
|
||||||
|
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
|
||||||
|
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||||
|
The non-linear activation function (function or string) in the decoder.
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||||
|
The epsilon used by the rms normalization layers.
|
||||||
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||||
|
relevant if `config.is_decoder=True`.
|
||||||
|
num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
|
||||||
|
Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
|
||||||
|
integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
|
||||||
|
logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
|
||||||
|
sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint
|
||||||
|
significantly.
|
||||||
|
pad_token_id (`int`, *optional*, defaults to 0):
|
||||||
|
The id of the padding token.
|
||||||
|
bos_token_id (`int`, *optional*, defaults to 1):
|
||||||
|
The id of the "beginning-of-sequence" token.
|
||||||
|
eos_token_id (`int`, *optional*, defaults to 2):
|
||||||
|
The id of the "end-of-sequence" token.
|
||||||
|
max_position_embeddings (`int`, *optional*, defaults to 262144):
|
||||||
|
Max cached sequence length for the model
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
attn_layer_indices (`list`, *optional*):
|
||||||
|
Specifies the layer indices that will have full attention. Must contain values at most num_hidden_layers.
|
||||||
|
mamba_n_heads (`int`, *optional*, defaults to 128):
|
||||||
|
The number of mamba heads used in the v2 implementation.
|
||||||
|
mamba_d_head (`int`, *optional*, defaults to `"auto"`):
|
||||||
|
Head embeddding dimension size
|
||||||
|
mamba_n_groups (`int`, *optional*, defaults to 1):
|
||||||
|
The number of the mamba groups used in the v2 implementation.
|
||||||
|
mamba_d_state (`int`, *optional*, defaults to 256):
|
||||||
|
The dimension the mamba state space latents
|
||||||
|
mamba_d_conv (`int`, *optional*, defaults to 4):
|
||||||
|
The size of the mamba convolution kernel
|
||||||
|
mamba_expand (`int`, *optional*, defaults to 2):
|
||||||
|
Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
|
||||||
|
mamba_chunk_size (`int`, *optional*, defaults to 256):
|
||||||
|
The chunks in which to break the sequence when doing prefill/training
|
||||||
|
mamba_conv_bias (`bool`, *optional*, defaults to `True`):
|
||||||
|
Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
|
||||||
|
mamba_proj_bias (`bool`, *optional*, defaults to `False`):
|
||||||
|
Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_type = "bamba"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=128000,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
hidden_size=4096,
|
||||||
|
intermediate_size=14336,
|
||||||
|
num_hidden_layers=32,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=8,
|
||||||
|
hidden_act="silu",
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-5,
|
||||||
|
use_cache=True,
|
||||||
|
num_logits_to_keep=1,
|
||||||
|
pad_token_id=0,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
max_position_embeddings=262144,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
attn_layer_indices=None,
|
||||||
|
mamba_n_heads=128,
|
||||||
|
mamba_d_head="auto",
|
||||||
|
mamba_n_groups=1,
|
||||||
|
mamba_d_state=256,
|
||||||
|
mamba_d_conv=4,
|
||||||
|
mamba_expand=2,
|
||||||
|
mamba_chunk_size=256,
|
||||||
|
mamba_conv_bias=True,
|
||||||
|
mamba_proj_bias=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.tie_word_embeddings = tie_word_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.attention_bias = False
|
||||||
|
self.mlp_bias = False
|
||||||
|
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.num_logits_to_keep = num_logits_to_keep
|
||||||
|
|
||||||
|
self.attn_layer_indices = attn_layer_indices
|
||||||
|
self.rope_theta = 10000.0
|
||||||
|
self.rope_scaling = None
|
||||||
|
self.partial_rotary_factor = 0.5
|
||||||
|
|
||||||
|
mamba_intermediate = mamba_expand * hidden_size
|
||||||
|
|
||||||
|
if mamba_intermediate % mamba_n_heads != 0:
|
||||||
|
raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size")
|
||||||
|
|
||||||
|
# for the mamba_v2, must satisfy the following
|
||||||
|
if mamba_d_head == "auto":
|
||||||
|
mamba_d_head = mamba_intermediate // mamba_n_heads
|
||||||
|
|
||||||
|
if mamba_d_head * mamba_n_heads != mamba_intermediate:
|
||||||
|
raise ValueError("The dimensions for the Mamba head state do not match the model intermediate_size")
|
||||||
|
|
||||||
|
self.mamba_n_heads = mamba_n_heads
|
||||||
|
self.mamba_d_head = mamba_d_head
|
||||||
|
self.mamba_n_groups = mamba_n_groups
|
||||||
|
self.mamba_d_state = mamba_d_state
|
||||||
|
self.mamba_d_conv = mamba_d_conv
|
||||||
|
self.mamba_expand = mamba_expand
|
||||||
|
self.mamba_chunk_size = mamba_chunk_size
|
||||||
|
self.mamba_conv_bias = mamba_conv_bias
|
||||||
|
self.mamba_proj_bias = mamba_proj_bias
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def layers_block_type(self):
|
||||||
|
return [
|
||||||
|
"attention" if (self.attn_layer_indices and i in self.attn_layer_indices) else "mamba"
|
||||||
|
for i in range(self.num_hidden_layers)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["BambaConfig"]
|
||||||
273
src/transformers/models/bamba/convert_mamba_ssm_checkpoint.py
Normal file
273
src/transformers/models/bamba/convert_mamba_ssm_checkpoint.py
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from os import path
|
||||||
|
from typing import Dict, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from huggingface_hub import split_torch_state_dict_into_shards
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
|
||||||
|
|
||||||
|
from .configuration_bamba import BambaConfig
|
||||||
|
|
||||||
|
|
||||||
|
def convert_state_dict_from_mamba_ssm(original_sd: Dict) -> Dict[str, torch.Tensor]:
|
||||||
|
state_dict = {}
|
||||||
|
|
||||||
|
for orig_k, param in original_sd.items():
|
||||||
|
k = orig_k.replace("backbone", "model")
|
||||||
|
|
||||||
|
# for embeddings
|
||||||
|
k = k.replace("embedding", "embed_tokens")
|
||||||
|
|
||||||
|
# for mixer
|
||||||
|
k = k.replace("mixer", "mamba")
|
||||||
|
|
||||||
|
# for final layernorm
|
||||||
|
k = k.replace("norm_f", "final_layernorm")
|
||||||
|
|
||||||
|
# for block layernorm
|
||||||
|
k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k)
|
||||||
|
k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k)
|
||||||
|
|
||||||
|
# for mlp
|
||||||
|
k = k.replace("mlp.fc2", "feed_forward.down_proj")
|
||||||
|
|
||||||
|
if "mlp.fc1" in k:
|
||||||
|
param, param2 = torch.chunk(param, 2, dim=0)
|
||||||
|
k2 = k.replace("mlp.fc1", "feed_forward.gate_proj")
|
||||||
|
state_dict[k2] = param2
|
||||||
|
k = k.replace("mlp.fc1", "feed_forward.up_proj")
|
||||||
|
|
||||||
|
if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or (
|
||||||
|
"out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd
|
||||||
|
):
|
||||||
|
# then this must be a mamba
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# for attn
|
||||||
|
# - because mixer was replaced to mamba above
|
||||||
|
k = k.replace("mamba.out_proj", "self_attn.o_proj")
|
||||||
|
if "mamba.in_proj" in k:
|
||||||
|
m, n = param.shape
|
||||||
|
d = (m - n) // 2
|
||||||
|
param, param2, param3 = torch.split(param, [n, d, d], dim=0)
|
||||||
|
k2 = k.replace("mamba.in_proj", "self_attn.k_proj")
|
||||||
|
state_dict[k2] = param2
|
||||||
|
k2 = k.replace("mamba.in_proj", "self_attn.v_proj")
|
||||||
|
state_dict[k2] = param3
|
||||||
|
k = k.replace("mamba.in_proj", "self_attn.q_proj")
|
||||||
|
|
||||||
|
state_dict[k] = param
|
||||||
|
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
|
||||||
|
def convert_ssm_config_to_hf_config(
|
||||||
|
config_ssm: Dict,
|
||||||
|
**kwargs,
|
||||||
|
) -> BambaConfig:
|
||||||
|
"""Convert a config from mamba_ssm to a BambaConfig from here."""
|
||||||
|
hf_config: BambaConfig = BambaConfig(**kwargs)
|
||||||
|
|
||||||
|
hf_config.architectures = ["BambaForCausalLM"]
|
||||||
|
|
||||||
|
# Set important values from config and recalculate other resulting entries
|
||||||
|
hf_config.hidden_size = config_ssm["d_model"]
|
||||||
|
hf_config.intermediate_size = config_ssm["d_intermediate"]
|
||||||
|
hf_config.mamba_n_heads = (hf_config.hidden_size * hf_config.mamba_expand) // hf_config.mamba_d_head
|
||||||
|
hf_config.num_hidden_layers = config_ssm["n_layer"]
|
||||||
|
hf_config.tie_word_embeddings = config_ssm["tie_embeddings"]
|
||||||
|
|
||||||
|
# currently this script assumes config_ssm belongs to v2
|
||||||
|
if config_ssm["ssm_cfg"].get("layer") != "Mamba2":
|
||||||
|
raise ValueError("Conversion script only supports Mamba2")
|
||||||
|
|
||||||
|
# Set attention values
|
||||||
|
attn_cfg = config_ssm.get("attn_cfg")
|
||||||
|
if attn_cfg:
|
||||||
|
assert attn_cfg["causal"], "Only support non-causal attention."
|
||||||
|
assert not attn_cfg["qkv_proj_bias"], "Only support no qkv bias."
|
||||||
|
assert not attn_cfg["out_proj_bias"], "Only support no out bias."
|
||||||
|
hf_config.attn_rotary_emb = attn_cfg["rotary_emb_dim"]
|
||||||
|
hf_config.num_attention_heads = attn_cfg["num_heads"]
|
||||||
|
hf_config.num_key_value_heads = attn_cfg["num_heads_kv"]
|
||||||
|
|
||||||
|
attention_layer_indices = config_ssm.get("attn_layer_idx")
|
||||||
|
if attention_layer_indices:
|
||||||
|
hf_config.attn_layer_indices = attention_layer_indices
|
||||||
|
|
||||||
|
# Padded vocab size, mostly of 16 but 32 is also very common in different models
|
||||||
|
vocab_size = config_ssm["vocab_size"]
|
||||||
|
pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"]
|
||||||
|
if (vocab_size % pad_vocab_size_multiple) != 0:
|
||||||
|
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
|
||||||
|
hf_config.vocab_size = vocab_size
|
||||||
|
|
||||||
|
return hf_config
|
||||||
|
|
||||||
|
|
||||||
|
def save_single_safetensor(
|
||||||
|
state_dict: Dict,
|
||||||
|
save_directory: str,
|
||||||
|
metadata: Dict,
|
||||||
|
):
|
||||||
|
save_file(
|
||||||
|
state_dict,
|
||||||
|
os.path.join(save_directory, SAFE_WEIGHTS_NAME),
|
||||||
|
metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def save_sharded_safetensors(
|
||||||
|
state_dict: Dict,
|
||||||
|
save_directory: str,
|
||||||
|
metadata: Dict,
|
||||||
|
max_shard_size: Union[int, str] = "5GB",
|
||||||
|
):
|
||||||
|
filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
|
||||||
|
".safetensors", "{suffix}.safetensors"
|
||||||
|
)
|
||||||
|
state_dict_split = split_torch_state_dict_into_shards(
|
||||||
|
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
|
||||||
|
)
|
||||||
|
index = {
|
||||||
|
"metadata": state_dict_split.metadata,
|
||||||
|
"weight_map": state_dict_split.tensor_to_filename,
|
||||||
|
}
|
||||||
|
# Save the index
|
||||||
|
with open(os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
|
||||||
|
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
filename_to_tensors = state_dict_split.filename_to_tensors.items()
|
||||||
|
for shard_file, tensors in filename_to_tensors:
|
||||||
|
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
|
||||||
|
save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
|
||||||
|
def convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
|
||||||
|
mamba_ssm_checkpoint_path: str,
|
||||||
|
precision: str,
|
||||||
|
output_dir: str,
|
||||||
|
tokenizer_path: str = None,
|
||||||
|
save_model: Union[bool, str] = True,
|
||||||
|
) -> None:
|
||||||
|
# load tokenizer if provided, this will be used to set the
|
||||||
|
# token_ids in the config file
|
||||||
|
token_ids = {}
|
||||||
|
if tokenizer_path:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||||
|
for key in [
|
||||||
|
"bos_token_id",
|
||||||
|
"eos_token_id",
|
||||||
|
"pad_token_id",
|
||||||
|
]:
|
||||||
|
id = getattr(tokenizer, key, None)
|
||||||
|
if id:
|
||||||
|
token_ids[key] = id
|
||||||
|
|
||||||
|
# there are some configs unsettable by mamba_ssn config, so
|
||||||
|
# if there are changes from the defaults, have to pass them into
|
||||||
|
# the function
|
||||||
|
unsettables = {
|
||||||
|
"mamba_d_head": 64,
|
||||||
|
"mamba_d_state": 128,
|
||||||
|
"mamba_n_groups": 1,
|
||||||
|
"rms_norm_eps": 1e-5,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Load and save config based on name
|
||||||
|
config_path = path.join(mamba_ssm_checkpoint_path, "config.json")
|
||||||
|
with open(config_path, "r", encoding="utf-8") as json_file:
|
||||||
|
config = json.load(json_file)
|
||||||
|
|
||||||
|
# convert the config
|
||||||
|
hf_config = convert_ssm_config_to_hf_config(
|
||||||
|
config_ssm=config,
|
||||||
|
**token_ids,
|
||||||
|
**unsettables,
|
||||||
|
)
|
||||||
|
hf_config.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
# Load state dict of the original model and transfer to hf model
|
||||||
|
state_dict = torch.load(
|
||||||
|
path.join(mamba_ssm_checkpoint_path, "pytorch_model.bin"),
|
||||||
|
map_location="cpu",
|
||||||
|
weights_only=True,
|
||||||
|
)
|
||||||
|
# FIXME: allow other parameters to pass in
|
||||||
|
state_dict = convert_state_dict_from_mamba_ssm(state_dict)
|
||||||
|
|
||||||
|
# Save new model to pytorch_dump_path
|
||||||
|
dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)
|
||||||
|
|
||||||
|
save_file_fn = None
|
||||||
|
if isinstance(save_model, bool) and save_model:
|
||||||
|
save_file_fn = save_single_safetensor
|
||||||
|
elif isinstance(save_model, str) and save_model == "sharded":
|
||||||
|
save_file_fn = save_sharded_safetensors
|
||||||
|
|
||||||
|
if save_file_fn:
|
||||||
|
save_file_fn({k: v.to(dtype) for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"})
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"-i",
|
||||||
|
"--mamba_ssm_checkpoint_directory",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-p",
|
||||||
|
"--precision",
|
||||||
|
type=str,
|
||||||
|
default="fp16",
|
||||||
|
const="fp16",
|
||||||
|
required=True,
|
||||||
|
choices=("fp32", "fp16", "bf16"),
|
||||||
|
help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-t",
|
||||||
|
"--tokenizer_model_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
required=False,
|
||||||
|
help="Path to a the tokenizer file.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
|
||||||
|
args.mamba2_checkpoint_directory,
|
||||||
|
args.precision,
|
||||||
|
args.output_dir,
|
||||||
|
)
|
||||||
1615
src/transformers/models/bamba/modeling_bamba.py
Normal file
1615
src/transformers/models/bamba/modeling_bamba.py
Normal file
File diff suppressed because it is too large
Load Diff
1303
src/transformers/models/bamba/modular_bamba.py
Normal file
1303
src/transformers/models/bamba/modular_bamba.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1167,6 +1167,27 @@ class AutoformerPreTrainedModel(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class BambaForCausalLM(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class BambaModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class BambaPreTrainedModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class BarkCausalModel(metaclass=DummyObject):
|
class BarkCausalModel(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -2313,6 +2313,7 @@ class GenerationTesterMixin:
|
|||||||
# 2. We ignore models that have unique cache structures (e.g. mamba) or are in need of refatoring to match the
|
# 2. We ignore models that have unique cache structures (e.g. mamba) or are in need of refatoring to match the
|
||||||
# standard cache format (e.g.gptbigcode )
|
# standard cache format (e.g.gptbigcode )
|
||||||
models_without_standard_cache = (
|
models_without_standard_cache = (
|
||||||
|
"bamba",
|
||||||
"ctrl",
|
"ctrl",
|
||||||
"fsmt",
|
"fsmt",
|
||||||
"gptbigcode",
|
"gptbigcode",
|
||||||
|
|||||||
0
tests/models/bamba/__init__.py
Normal file
0
tests/models/bamba/__init__.py
Normal file
603
tests/models/bamba/test_modeling_bamba.py
Normal file
603
tests/models/bamba/test_modeling_bamba.py
Normal file
@@ -0,0 +1,603 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Testing suite for the PyTorch Bamba model."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer, BambaConfig, is_torch_available
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
require_torch,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
|
from ...test_configuration_common import ConfigTester
|
||||||
|
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor
|
||||||
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
BambaForCausalLM,
|
||||||
|
BambaModel,
|
||||||
|
)
|
||||||
|
from transformers.models.bamba.modeling_bamba import (
|
||||||
|
HybridMambaAttentionDynamicCache,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BambaModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=13,
|
||||||
|
seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
use_input_mask=True,
|
||||||
|
use_labels=True,
|
||||||
|
vocab_size=99,
|
||||||
|
hidden_size=32,
|
||||||
|
num_hidden_layers=4,
|
||||||
|
num_attention_heads=4,
|
||||||
|
num_key_value_heads=2,
|
||||||
|
intermediate_size=64,
|
||||||
|
hidden_act="silu",
|
||||||
|
attention_dropout=0.0,
|
||||||
|
attn_layer_indices=None,
|
||||||
|
attn_rotary_emb=8,
|
||||||
|
max_position_embeddings=512,
|
||||||
|
type_vocab_size=16,
|
||||||
|
initializer_range=0.02,
|
||||||
|
num_labels=3,
|
||||||
|
pad_token_id=0,
|
||||||
|
mamba_n_groups=1,
|
||||||
|
mamba_n_heads=16,
|
||||||
|
mamba_d_state=16,
|
||||||
|
mamba_d_conv=4,
|
||||||
|
mamba_expand=2,
|
||||||
|
mamba_chunk_size=16,
|
||||||
|
scope=None,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.seq_length = seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_input_mask = use_input_mask
|
||||||
|
self.use_labels = use_labels
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.attn_layer_indices = attn_layer_indices
|
||||||
|
self.attn_rotary_emb = attn_rotary_emb
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.type_vocab_size = type_vocab_size
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.num_labels = num_labels
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.scope = scope
|
||||||
|
self.mamba_n_groups = mamba_n_groups
|
||||||
|
self.mamba_n_heads = mamba_n_heads
|
||||||
|
self.mamba_d_state = mamba_d_state
|
||||||
|
self.mamba_d_conv = mamba_d_conv
|
||||||
|
self.mamba_expand = mamba_expand
|
||||||
|
self.mamba_chunk_size = mamba_chunk_size
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
input_mask = None
|
||||||
|
if self.use_input_mask:
|
||||||
|
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||||
|
|
||||||
|
token_labels = None
|
||||||
|
if self.use_labels:
|
||||||
|
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||||
|
|
||||||
|
config = self.get_config()
|
||||||
|
|
||||||
|
return config, input_ids, input_mask, token_labels
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
input_mask,
|
||||||
|
token_labels,
|
||||||
|
) = config_and_inputs
|
||||||
|
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
# Fix for SDPA tests, force at least 4 layers
|
||||||
|
if self.num_hidden_layers < 4:
|
||||||
|
self.num_hidden_layers = 4
|
||||||
|
if self.attn_layer_indices is None:
|
||||||
|
d = [x for x in range(2, self.num_hidden_layers) if self.num_hidden_layers % x == 0]
|
||||||
|
if len(d) == 0:
|
||||||
|
raise ValueError("num_hidden_layers is prime, cannot automatically set attn_layer_indices.")
|
||||||
|
d = d[-1] # get the largest divisor
|
||||||
|
self.attn_layer_indices = [x + 1 for x in range(0, self.num_hidden_layers, d)]
|
||||||
|
|
||||||
|
return BambaConfig(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_hidden_layers=self.num_hidden_layers,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
num_key_value_heads=self.num_key_value_heads,
|
||||||
|
intermediate_size=self.intermediate_size,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
|
attention_dropout=self.attention_dropout,
|
||||||
|
attn_layer_indices=self.attn_layer_indices,
|
||||||
|
attn_rotary_emb=self.attn_rotary_emb,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
initializer_range=self.initializer_range,
|
||||||
|
pad_token_id=self.pad_token_id,
|
||||||
|
mamba_n_groups=self.mamba_n_groups,
|
||||||
|
mamba_n_heads=self.mamba_n_heads,
|
||||||
|
mamba_d_state=self.mamba_d_state,
|
||||||
|
mamba_d_conv=self.mamba_d_conv,
|
||||||
|
mamba_expand=self.mamba_expand,
|
||||||
|
mamba_chunk_size=self.mamba_chunk_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_model(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
input_mask,
|
||||||
|
token_labels,
|
||||||
|
):
|
||||||
|
model = BambaModel(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(input_ids, attention_mask=input_mask)
|
||||||
|
result = model(input_ids)
|
||||||
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||||
|
|
||||||
|
def create_and_check_for_causal_lm(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
input_mask,
|
||||||
|
token_labels,
|
||||||
|
):
|
||||||
|
model = BambaForCausalLM(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||||
|
result = model(input_ids, attention_mask=input_mask)
|
||||||
|
result = model(input_ids, labels=token_labels)
|
||||||
|
result = model(input_ids)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_past_large_inputs(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
input_mask,
|
||||||
|
token_labels,
|
||||||
|
):
|
||||||
|
# config.is_decoder = True
|
||||||
|
# config.add_cross_attention = True
|
||||||
|
model = BambaForCausalLM(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
# Attention: Jamba needs the cache to be initialized to return a cache!
|
||||||
|
past_key_values = HybridMambaAttentionDynamicCache(
|
||||||
|
config, input_ids.shape[0], model.dtype, device=model.device
|
||||||
|
)
|
||||||
|
outputs = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=True,
|
||||||
|
)
|
||||||
|
past_key_values = outputs.past_key_values
|
||||||
|
|
||||||
|
# create hypothetical multiple next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||||
|
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
|
||||||
|
|
||||||
|
# append to next input_ids and
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
|
||||||
|
|
||||||
|
output_from_no_past = model(
|
||||||
|
next_input_ids,
|
||||||
|
attention_mask=next_attention_mask,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)["hidden_states"][0]
|
||||||
|
output_from_past = model(
|
||||||
|
next_tokens,
|
||||||
|
attention_mask=next_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
output_hidden_states=True,
|
||||||
|
cache_position=torch.arange(
|
||||||
|
input_ids.shape[1], input_ids.shape[1] + next_tokens.shape[1], device=model.device
|
||||||
|
),
|
||||||
|
)["hidden_states"][0]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
|
||||||
|
|
||||||
|
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (
|
||||||
|
(
|
||||||
|
BambaModel,
|
||||||
|
BambaForCausalLM,
|
||||||
|
)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
|
)
|
||||||
|
all_generative_model_classes = (BambaForCausalLM,) if is_torch_available() else ()
|
||||||
|
pipeline_model_mapping = (
|
||||||
|
{
|
||||||
|
"feature-extraction": BambaModel,
|
||||||
|
"text-generation": BambaForCausalLM,
|
||||||
|
}
|
||||||
|
if is_torch_available()
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
test_headmasking = False
|
||||||
|
test_pruning = False
|
||||||
|
fx_compatible = False
|
||||||
|
|
||||||
|
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
|
||||||
|
# This is because we are hitting edge cases with the causal_mask buffer
|
||||||
|
model_split_percents = [0.5, 0.7, 0.8]
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = BambaModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=BambaConfig, hidden_size=64)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_model(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_casual_lm(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_decoder_model_past_with_large_inputs(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_initialization(self):
|
||||||
|
r"""
|
||||||
|
Overriding the test_initialization test as the A_log and D params of the Bamba mixer are initialized differently
|
||||||
|
"""
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config)
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
if "A_log" in name:
|
||||||
|
A = torch.arange(1, config.mamba_n_heads + 1, dtype=torch.float32)[None, :]
|
||||||
|
self.assertTrue(torch.allclose(param.data, torch.log(A), atol=1e-5, rtol=1e-5))
|
||||||
|
elif "D" in name:
|
||||||
|
D = torch.ones(config.mamba_n_heads, dtype=torch.float32)
|
||||||
|
self.assertTrue(torch.allclose(param.data, D, atol=1e-5, rtol=1e-5))
|
||||||
|
else:
|
||||||
|
self.assertIn(
|
||||||
|
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||||
|
[0.0, 1.0],
|
||||||
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_mismatched_shapes_have_properly_initialized_weights(self):
|
||||||
|
r"""
|
||||||
|
Overriding the test_mismatched_shapes_have_properly_initialized_weights test because A_log and D params of the
|
||||||
|
Bamba mixer are initialized differently and we tested that in test_initialization
|
||||||
|
"""
|
||||||
|
self.skipTest(reason="Cumbersome and redundant for Bamba")
|
||||||
|
|
||||||
|
def test_attention_outputs(self):
|
||||||
|
r"""
|
||||||
|
Overriding the test_attention_outputs test as the Bamba model outputs attention only for its attention layers
|
||||||
|
"""
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.return_dict = True
|
||||||
|
|
||||||
|
seq_len = getattr(self.model_tester, "seq_length", None)
|
||||||
|
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
||||||
|
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||||
|
|
||||||
|
expected_num_attentions = self.model_tester.num_hidden_layers - len(self.model_tester.attn_layer_indices)
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
inputs_dict["output_attentions"] = True
|
||||||
|
inputs_dict["output_hidden_states"] = False
|
||||||
|
config.return_dict = True
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
attentions = outputs.attentions
|
||||||
|
self.assertEqual(len(attentions), expected_num_attentions)
|
||||||
|
|
||||||
|
# check that output_attentions also work using config
|
||||||
|
del inputs_dict["output_attentions"]
|
||||||
|
config.output_attentions = True
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
attentions = outputs.attentions
|
||||||
|
self.assertEqual(len(attentions), expected_num_attentions)
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
list(attentions[0].shape[-3:]),
|
||||||
|
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||||
|
)
|
||||||
|
out_len = len(outputs)
|
||||||
|
|
||||||
|
# Check attention is always last and order is fine
|
||||||
|
inputs_dict["output_attentions"] = True
|
||||||
|
inputs_dict["output_hidden_states"] = True
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
added_hidden_states = 1
|
||||||
|
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||||
|
|
||||||
|
self_attentions = outputs.attentions
|
||||||
|
|
||||||
|
self.assertEqual(len(self_attentions), expected_num_attentions)
|
||||||
|
self.assertListEqual(
|
||||||
|
list(self_attentions[0].shape[-3:]),
|
||||||
|
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||||
|
)
|
||||||
|
|
||||||
|
@unittest.skip(reason="Bamba has its own special cache type")
|
||||||
|
@parameterized.expand([(1, False), (1, True), (4, False)])
|
||||||
|
def test_new_cache_format(self, num_beams, do_sample):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_batching_equivalence(self):
|
||||||
|
# need to disable the tril input mask
|
||||||
|
orig = self.model_tester.use_input_mask
|
||||||
|
self.model_tester.use_input_mask = False
|
||||||
|
super().test_batching_equivalence()
|
||||||
|
self.model_tester.use_input_mask = orig
|
||||||
|
|
||||||
|
# essentially the same test in test_utils, just adjustment for rtol for this model
|
||||||
|
@pytest.mark.generate
|
||||||
|
def test_left_padding_compatibility(self):
|
||||||
|
# NOTE: left-padding results in small numerical differences. This is expected.
|
||||||
|
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
|
||||||
|
|
||||||
|
# First, filter out models that don't support left padding
|
||||||
|
# - The model must have generative capabilities
|
||||||
|
if len(self.all_generative_model_classes) == 0:
|
||||||
|
self.skipTest(reason="No generative architecture available for this model.")
|
||||||
|
|
||||||
|
# - The model must support padding
|
||||||
|
if not self.has_attentions:
|
||||||
|
self.skipTest(reason="This model doesn't support padding.")
|
||||||
|
|
||||||
|
# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
|
||||||
|
decoder_only_classes = []
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
config, _ = self.prepare_config_and_inputs_for_generate()
|
||||||
|
if config.is_encoder_decoder:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
decoder_only_classes.append(model_class)
|
||||||
|
if len(decoder_only_classes) == 0:
|
||||||
|
self.skipTest(reason="No decoder-only architecture available for this model.")
|
||||||
|
|
||||||
|
# - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't
|
||||||
|
# added support for it yet. We skip these models for now.
|
||||||
|
has_encoder_attributes = any(
|
||||||
|
attr_name
|
||||||
|
for attr_name in config.to_dict().keys()
|
||||||
|
if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size"
|
||||||
|
)
|
||||||
|
if has_encoder_attributes:
|
||||||
|
self.skipTest(
|
||||||
|
reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Then, test left-padding
|
||||||
|
def _prepare_model_kwargs(input_ids, attention_mask, signature):
|
||||||
|
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||||
|
if "position_ids" in signature:
|
||||||
|
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
|
||||||
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
model_kwargs["position_ids"] = position_ids
|
||||||
|
if "cache_position" in signature:
|
||||||
|
cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
|
||||||
|
model_kwargs["cache_position"] = cache_position
|
||||||
|
return model_kwargs
|
||||||
|
|
||||||
|
for model_class in decoder_only_classes:
|
||||||
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
|
input_ids = inputs_dict["input_ids"]
|
||||||
|
|
||||||
|
# - for left padding we absolutely need to use an all ones
|
||||||
|
# attention mask, so we do not use the one in inputs_dict
|
||||||
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
|
||||||
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
signature = inspect.signature(model.forward).parameters.keys()
|
||||||
|
|
||||||
|
# no cache as some models require special cache classes to be init outside forward
|
||||||
|
model.generation_config.use_cache = False
|
||||||
|
|
||||||
|
# Without padding
|
||||||
|
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
|
||||||
|
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
|
||||||
|
|
||||||
|
# With left-padding (length 32)
|
||||||
|
# can hardcode pad_token to be 0 as we'll do attn masking anyway
|
||||||
|
pad_token_id = (
|
||||||
|
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
|
||||||
|
)
|
||||||
|
pad_size = (input_ids.shape[0], 32)
|
||||||
|
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
|
||||||
|
padded_input_ids = torch.cat((padding, input_ids), dim=1)
|
||||||
|
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
|
||||||
|
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
|
||||||
|
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
|
||||||
|
|
||||||
|
# They should result in very similar logits
|
||||||
|
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, atol=1e-5, rtol=1e-1)
|
||||||
|
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
class BambaModelIntegrationTest(unittest.TestCase):
|
||||||
|
model = None
|
||||||
|
tokenizer = None
|
||||||
|
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
|
||||||
|
# Depending on the hardware we get different logits / generations
|
||||||
|
cuda_compute_capability_major_version = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
model_id = "ibm-fms/Bamba-9B"
|
||||||
|
cls.model = BambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
|
||||||
|
cls.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
|
||||||
|
# feels a bit forced to have to do this for the generation test
|
||||||
|
cls.tokenizer.pad_token_id = cls.model.config.pad_token_id
|
||||||
|
cls.tokenizer.padding_side = "left"
|
||||||
|
|
||||||
|
if is_torch_available() and torch.cuda.is_available():
|
||||||
|
# 8 is for A100 / A10 and 7 for T4
|
||||||
|
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||||
|
|
||||||
|
def test_simple_generate(self):
|
||||||
|
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
|
||||||
|
#
|
||||||
|
# Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
|
||||||
|
# considering differences in hardware processing and potential deviations in generated text.
|
||||||
|
EXPECTED_TEXTS = {
|
||||||
|
# 7: "",
|
||||||
|
8: "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all having a good time.",
|
||||||
|
# 9: """,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.model.to(torch_device)
|
||||||
|
|
||||||
|
input_ids = self.tokenizer("Hey how are you doing on this lovely evening?", return_tensors="pt")[
|
||||||
|
"input_ids"
|
||||||
|
].to(torch_device)
|
||||||
|
out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10)
|
||||||
|
output_sentence = self.tokenizer.decode(out[0, :])
|
||||||
|
self.assertEqual(output_sentence, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||||
|
|
||||||
|
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
||||||
|
if self.cuda_compute_capability_major_version == 8:
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = self.model(input_ids=input_ids, num_logits_to_keep=40).logits
|
||||||
|
|
||||||
|
EXPECTED_LOGITS_NO_GRAD = torch.tensor(
|
||||||
|
[
|
||||||
|
149., 142., 146., 142., 143., 144., 142., 145.,
|
||||||
|
142., 146., 144., 146., 147., 147., 148., 145.,
|
||||||
|
147., 145., 145., 145., 145., 144., 144., 144.,
|
||||||
|
144., 145., 147., 146., 144., 144., 148., 147.,
|
||||||
|
148., 147., 147., 147., 146., 146., 148., 148.
|
||||||
|
], dtype=torch.bfloat16) # fmt: skip
|
||||||
|
|
||||||
|
torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1)
|
||||||
|
|
||||||
|
def test_simple_batched_generate_with_padding(self):
|
||||||
|
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
|
||||||
|
#
|
||||||
|
# Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
|
||||||
|
# considering differences in hardware processing and potential deviations in generated text.
|
||||||
|
EXPECTED_TEXTS = {
|
||||||
|
7: [],
|
||||||
|
8: [
|
||||||
|
"<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
|
||||||
|
"!!!<|begin_of_text|>I am late! I need to get to work! I have to get to the",
|
||||||
|
],
|
||||||
|
9: [],
|
||||||
|
}
|
||||||
|
|
||||||
|
self.model.to(torch_device)
|
||||||
|
|
||||||
|
inputs = self.tokenizer(
|
||||||
|
["Hey how are you doing on this lovely evening?", "I am late! I need to"],
|
||||||
|
padding=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(torch_device)
|
||||||
|
out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||||
|
output_sentences = self.tokenizer.batch_decode(out)
|
||||||
|
self.assertEqual(output_sentences[0], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][0])
|
||||||
|
self.assertEqual(output_sentences[1], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][1])
|
||||||
|
|
||||||
|
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
||||||
|
if self.cuda_compute_capability_major_version == 8:
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = self.model(input_ids=inputs["input_ids"]).logits
|
||||||
|
|
||||||
|
EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
|
||||||
|
[
|
||||||
|
149., 142., 146., 142., 143., 144., 142., 145.,
|
||||||
|
142., 146., 144., 146., 147., 147., 148., 145.,
|
||||||
|
147., 145., 145., 145., 145., 144., 144., 144.,
|
||||||
|
144., 145., 147., 146., 144., 144., 148., 147.,
|
||||||
|
148., 147., 147., 147., 146., 146., 148., 148.
|
||||||
|
], dtype=torch.bfloat16) # fmt: skip
|
||||||
|
|
||||||
|
EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
|
||||||
|
[
|
||||||
|
182., 178., 177., 174., 176., 176., 178., 178.,
|
||||||
|
177., 179., 176., 183., 180., 182., 179., 174.,
|
||||||
|
178., 176., 176., 175., 175., 175., 174., 173.,
|
||||||
|
174., 182., 180., 176., 177., 177., 180., 176.,
|
||||||
|
178., 177., 177., 175., 176., 177., 175., 177.
|
||||||
|
], dtype=torch.bfloat16) # fmt: skip
|
||||||
|
|
||||||
|
torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1)
|
||||||
|
torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1)
|
||||||
@@ -34,6 +34,9 @@ CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
|
|||||||
SPECIAL_CASES_TO_ALLOW = {
|
SPECIAL_CASES_TO_ALLOW = {
|
||||||
# 'max_position_embeddings' is not used in modeling file, but needed for eval frameworks like Huggingface's lighteval (https://github.com/huggingface/lighteval/blob/af24080ea4f16eaf1683e353042a2dfc9099f038/src/lighteval/models/base_model.py#L264).
|
# 'max_position_embeddings' is not used in modeling file, but needed for eval frameworks like Huggingface's lighteval (https://github.com/huggingface/lighteval/blob/af24080ea4f16eaf1683e353042a2dfc9099f038/src/lighteval/models/base_model.py#L264).
|
||||||
# periods and offsers are not used in modeling file, but used in the configuration file to define `layers_block_type` and `layers_num_experts`.
|
# periods and offsers are not used in modeling file, but used in the configuration file to define `layers_block_type` and `layers_num_experts`.
|
||||||
|
"BambaConfig": [
|
||||||
|
"attn_layer_indices",
|
||||||
|
],
|
||||||
"JambaConfig": [
|
"JambaConfig": [
|
||||||
"max_position_embeddings",
|
"max_position_embeddings",
|
||||||
"attn_layer_offset",
|
"attn_layer_offset",
|
||||||
|
|||||||
Reference in New Issue
Block a user