From f319ba16fa89d4a48864af94d1da4abb98c4a8ab Mon Sep 17 00:00:00 2001 From: pglorio <85982602+pglorio@users.noreply.github.com> Date: Fri, 4 Oct 2024 13:28:05 -0700 Subject: [PATCH] Add Zamba (#30950) * Update index.md * Rebase * Rebase * Updates from make fixup * Update zamba.md * Batched inference * Update * Fix tests * Fix tests * Fix tests * Fix tests * Update docs/source/en/model_doc/zamba.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update docs/source/en/model_doc/zamba.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update configuration_zamba.py * Update src/transformers/models/zamba/modeling_zamba.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/zamba/modeling_zamba.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/zamba/modeling_zamba.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/zamba/modeling_zamba.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update modeling_zamba.py * Update modeling_zamba.py * Update modeling_zamba.py * Update configuration_zamba.py * Update modeling_zamba.py * Update modeling_zamba.py * Merge branch 'main' of https://github.com/Zyphra/transformers_zamba * Update ZambaForCausalLM * Update ZambaForCausalLM * Describe diffs with original mamba layer * Moved mamba init into `_init_weights` * Update index.md * Rebase * Rebase * Updates from make fixup * Update zamba.md * Batched inference * Update * Fix tests * Fix tests * Fix tests * Fix tests * Update docs/source/en/model_doc/zamba.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update docs/source/en/model_doc/zamba.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update configuration_zamba.py * Update src/transformers/models/zamba/modeling_zamba.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/zamba/modeling_zamba.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/zamba/modeling_zamba.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/zamba/modeling_zamba.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update modeling_zamba.py * Update modeling_zamba.py * Update modeling_zamba.py * Update configuration_zamba.py * Update modeling_zamba.py * Update modeling_zamba.py * Merge branch 'main' of https://github.com/Zyphra/transformers_zamba * Update ZambaForCausalLM * Moved mamba init into `_init_weights` * Update ZambaForCausalLM * Describe diffs with original mamba layer * make fixup fixes * quality test fixes * Fix Zamba model path * circleci fixes * circleci fixes * circleci fixes * circleci fixes * circleci fixes * circleci fixes * circleci fixes * circleci fixes * circleci fixes * Update * circleci fixes * fix zamba test from merge * fix ValueError for disabling mamba kernels * add HF copyright Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * shared_transf --> shared_transformer * Update src/transformers/models/zamba/modeling_zamba.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/zamba/modeling_zamba.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Fixes * Move attention head dim to config * Fix circle/ci tests * Update modeling_zamba.py * apply GenerationMixin inheritance change from upstream * apply import ordering * update needed transformers version for zamba Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add contribution author * add @slow to avoid CI * Update src/transformers/models/zamba/modeling_zamba.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Define attention_hidden_size * Added doc for attention_head_size * trigger CI * Fix doc of attention_hidden_size * [run-slow] zamba * Fixed shared layer logic, swapped up<->gate in mlp * shared_transformer -> shared_transf * reformat HybridLayer __init__ * fix docstrings in zamba config * added definition of _get_input_ids_and_config * fixed formatting of _get_input_ids_and_config --------- Co-authored-by: root Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: root Co-authored-by: Quentin Anthony --- docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 1 + docs/source/en/model_doc/zamba.md | 100 + src/transformers/__init__.py | 16 + src/transformers/generation/utils.py | 6 +- src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 3 + .../models/auto/tokenization_auto.py | 7 + src/transformers/models/zamba/__init__.py | 57 + .../models/zamba/configuration_zamba.py | 224 +++ .../models/zamba/modeling_zamba.py | 1739 +++++++++++++++++ src/transformers/utils/dummy_pt_objects.py | 28 + tests/generation/test_utils.py | 12 +- tests/models/zamba/__init__.py | 0 tests/models/zamba/test_modeling_zamba.py | 736 +++++++ utils/check_config_attributes.py | 5 + utils/not_doctested.txt | 2 + 18 files changed, 2939 insertions(+), 2 deletions(-) create mode 100644 docs/source/en/model_doc/zamba.md create mode 100644 src/transformers/models/zamba/__init__.py create mode 100644 src/transformers/models/zamba/configuration_zamba.py create mode 100644 src/transformers/models/zamba/modeling_zamba.py create mode 100644 tests/models/zamba/__init__.py create mode 100644 tests/models/zamba/test_modeling_zamba.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 0e96beedea..49aa64f815 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -711,6 +711,8 @@ title: ViTMSN - local: model_doc/yolos title: YOLOS + - local: model_doc/zamba + title: Zamba - local: model_doc/zoedepth title: ZoeDepth title: Vision models diff --git a/docs/source/en/index.md b/docs/source/en/index.md index dd22d58350..32a730e6bc 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -361,6 +361,7 @@ Flax), PyTorch, and/or TensorFlow. | [XLSR-Wav2Vec2](model_doc/xlsr_wav2vec2) | ✅ | ✅ | ✅ | | [YOLOS](model_doc/yolos) | ✅ | ❌ | ❌ | | [YOSO](model_doc/yoso) | ✅ | ❌ | ❌ | +| [Zamba](model_doc/zamba) | ✅ | ❌ | ❌ | | [ZoeDepth](model_doc/zoedepth) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/zamba.md b/docs/source/en/model_doc/zamba.md new file mode 100644 index 0000000000..450b68c77d --- /dev/null +++ b/docs/source/en/model_doc/zamba.md @@ -0,0 +1,100 @@ + +# Zamba + +Zamba is a large language model (LLM) trained by Zyphra, and made available under an Apache 2.0 license. Please see the [Zyphra Hugging Face](https://huggingface.co/collections/zyphra/) repository for model weights. + +This model was contributed by [pglo](https://huggingface.co/pglo). + + +## Model details + +Zamba-7B-v1 is a hybrid between state-space models (Specifically [Mamba](https://github.com/state-spaces/mamba)) and transformer, and was trained using next-token prediction. Zamba uses a shared transformer layer after every 6 mamba blocks. It uses the [Mistral v0.1 tokenizer](https://huggingface.co/mistralai/Mistral-7B-v0.1). We came to this architecture after a series of ablations at small scales. Zamba-7B-v1 was pre-trained on 1T tokens of text and code data. + + + +## Quick start + + +### Presequities + +Zamba requires you use `transformers` version 4.46.0 or higher: +```bash +pip install transformers>=4.45.0 +``` + +In order to run optimized Mamba implementations, you first need to install `mamba-ssm` and `causal-conv1d`: +```bash +pip install mamba-ssm causal-conv1d>=1.2.0 +``` +You also have to have the model on a CUDA device. + +You can run the model not using the optimized Mamba kernels, but it is **not** recommended as it will result in significantly lower latencies. In order to do that, you'll need to specify `use_mamba_kernels=False` when loading the model. + + +## Inference + +```python +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch + +tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba-7B-v1") +model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba-7B-v1", device_map="auto", torch_dtype=torch.bfloat16) + +input_text = "A funny prompt would be " +input_ids = tokenizer(input_text, return_tensors="pt").to("cuda") + +outputs = model.generate(**input_ids, max_new_tokens=100) +print(tokenizer.decode(outputs[0])) +``` + + +## Model card + +The model cards can be found at: +* [Zamba-7B](MODEL_CARD_ZAMBA-7B-v1.md) + + +## Issues +For issues with model output, or community discussion, please use the Hugging Face community [forum](https://huggingface.co/zyphra/zamba-7b) + + +## License + +The model weights are open-sourced via an Apache 2.0 license. + + +## ZambaConfig + +[[autodoc]] ZambaConfig + + +## ZambaModel + +[[autodoc]] ZambaModel + - forward + + +## ZambaForCausalLM + +[[autodoc]] ZambaForCausalLM + - forward + + +## ZambaForSequenceClassification + +[[autodoc]] transformers.ZambaForSequenceClassification + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 79a6bf004b..667d51cb2a 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -847,6 +847,7 @@ _import_structure = { "models.xmod": ["XmodConfig"], "models.yolos": ["YolosConfig"], "models.yoso": ["YosoConfig"], + "models.zamba": ["ZambaConfig"], "models.zoedepth": ["ZoeDepthConfig"], "onnx": [], "pipelines": [ @@ -3759,6 +3760,14 @@ else: "YosoPreTrainedModel", ] ) + _import_structure["models.zamba"].extend( + [ + "ZambaForCausalLM", + "ZambaForSequenceClassification", + "ZambaModel", + "ZambaPreTrainedModel", + ] + ) _import_structure["models.zoedepth"].extend( [ "ZoeDepthForDepthEstimation", @@ -5729,6 +5738,7 @@ if TYPE_CHECKING: from .models.xmod import XmodConfig from .models.yolos import YolosConfig from .models.yoso import YosoConfig + from .models.zamba import ZambaConfig from .models.zoedepth import ZoeDepthConfig # Pipelines @@ -8126,6 +8136,12 @@ if TYPE_CHECKING: YosoModel, YosoPreTrainedModel, ) + from .models.zamba import ( + ZambaForCausalLM, + ZambaForSequenceClassification, + ZambaModel, + ZambaPreTrainedModel, + ) from .models.zoedepth import ( ZoeDepthForDepthEstimation, ZoeDepthPreTrainedModel, diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 06b2654248..43eda33314 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1581,7 +1581,11 @@ class GenerationMixin: order to save memory (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed for `HybridMambaAttentionDynamicCache`). """ - return self._supports_cache_class and "jamba" not in self.__class__.__name__.lower() + return ( + self._supports_cache_class + and "jamba" not in self.__class__.__name__.lower() + and "zamba" not in self.__class__.__name__.lower() + ) def _prepare_cache_for_generation( self, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 46d1e523ac..12333c76a5 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -281,5 +281,6 @@ from . import ( xmod, yolos, yoso, + zamba, zoedepth, ) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index c21a523faf..b974daebfd 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -311,6 +311,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("xmod", "XmodConfig"), ("yolos", "YolosConfig"), ("yoso", "YosoConfig"), + ("zamba", "ZambaConfig"), ("zoedepth", "ZoeDepthConfig"), ] ) @@ -630,6 +631,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("xmod", "X-MOD"), ("yolos", "YOLOS"), ("yoso", "YOSO"), + ("zamba", "Zamba"), ("zoedepth", "ZoeDepth"), ] ) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index a1d741017c..8b990ba1f0 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -284,6 +284,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("xmod", "XmodModel"), ("yolos", "YolosModel"), ("yoso", "YosoModel"), + ("zamba", "ZambaModel"), ] ) @@ -546,6 +547,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("xlm-roberta-xl", "XLMRobertaXLForCausalLM"), ("xlnet", "XLNetLMHeadModel"), ("xmod", "XmodForCausalLM"), + ("zamba", "ZambaForCausalLM"), ] ) @@ -974,6 +976,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("xlnet", "XLNetForSequenceClassification"), ("xmod", "XmodForSequenceClassification"), ("yoso", "YosoForSequenceClassification"), + ("zamba", "ZambaForSequenceClassification"), ] ) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 9305245f1a..f5b029414d 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -555,6 +555,13 @@ else: "AlbertTokenizerFast" if is_tokenizers_available() else None, ), ), + ( + "zamba", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), ] ) diff --git a/src/transformers/models/zamba/__init__.py b/src/transformers/models/zamba/__init__.py new file mode 100644 index 0000000000..e92890d1a7 --- /dev/null +++ b/src/transformers/models/zamba/__init__.py @@ -0,0 +1,57 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + + +_import_structure = { + "configuration_zamba": ["ZambaConfig"], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_zamba"] = [ + "ZambaForCausalLM", + "ZambaForSequenceClassification", + "ZambaModel", + "ZambaPreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_zamba import ZambaConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_zamba import ( + ZambaForCausalLM, + ZambaForSequenceClassification, + ZambaModel, + ZambaPreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/zamba/configuration_zamba.py b/src/transformers/models/zamba/configuration_zamba.py new file mode 100644 index 0000000000..a6764a8260 --- /dev/null +++ b/src/transformers/models/zamba/configuration_zamba.py @@ -0,0 +1,224 @@ +# coding=utf-8 +# Copyright 2024 Zyphra Technologies and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Zamba model configuration""" + +import math + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class ZambaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ZambaModel`]. It is used to instantiate a + Zamba model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Zamba-v0.1 model. + + [Zyphra/Zamba-7B-v1](https://huggingface.co/Zyphra/Zamba-7B-v1) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Zamba model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ZambaModel`] + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the + model has a output word embedding layer. + hidden_size (`int`, *optional*, defaults to 3712): + Dimension of the hidden representations. + attention_hidden_size (`int`, *optional*): + Dimension of the hidden representations of the inputs to the Attention layer. + intermediate_size (`int`, *optional*, defaults to 14848): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 76): + Number of hidden layers in the model. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + attention_head_dim (`int`, *optional*): + Dimension of the attention head in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 16): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=None`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). + n_mamba_heads (`int`, *optional*, defaults to 2): + Number of mamba heads for each mamba layer. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the decoder. + hidden_mamba_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the mamba layer. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + num_logits_to_keep (`int` or `None`, *optional*, defaults to 1): + Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an + integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the + logits of the last prompt token are needed for generation. For long sequences, the logits for the entire + sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint + significantly. + pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + max_position_embeddings (`int`, *optional*, defaults to 4096): + This value doesn't have any real effect. The maximum sequence length that this model is intended to be + used with. It can be used with longer sequences, but performance may degrade. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + attn_layer_period (`int`, *optional*, defaults to 6): + Once in this many layers, we will have a shared attention layer + attn_layer_offset (`int`, *optional*, defaults to 4): + Offset of the shared attention layer + use_mamba_kernels (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and + `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. Raises ValueError if + `True` and kernels are not available + mamba_d_state (`int`, *optional*, defaults to 16): + The dimension the mamba state space latents + mamba_d_conv (`int`, *optional*, defaults to 4): + The size of the mamba convolution kernel + mamba_expand (`int`, *optional*, defaults to 2): + Expanding factor (relative to hidden_size) used to determine the mamba intermediate size + mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`): + Rank of the the mamba discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)` + time_step_min (`float`, *optional*, defaults to 0.001): + Minimum `time_step` used to bound `dt_proj_bias`. + time_step_max (`float`, *optional*, defaults to 0.1): + Maximum `time_step` used to bound `dt_proj_bias`. + time_step_floor (`float`, *optional*, defaults to 0.0001): + Minimum clamping value of the `dt_proj.bias` layer initialization. + mamba_conv_bias (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block. + mamba_proj_bias (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block + + """ + + model_type = "zamba" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + tie_word_embeddings=True, + hidden_size=3712, + attention_hidden_size=None, + intermediate_size=14848, + num_hidden_layers=76, + num_attention_heads=16, + attention_head_dim=None, + num_key_value_heads=16, + n_mamba_heads=2, + hidden_act="gelu", + hidden_mamba_act="silu", + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + num_logits_to_keep=1, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + max_position_embeddings=4096, + attention_dropout=0.0, + attn_layer_period=6, + attn_layer_offset=4, + use_mamba_kernels=True, + mamba_d_state=16, + mamba_d_conv=4, + mamba_expand=2, + mamba_dt_rank="auto", + time_step_min=0.001, + time_step_max=0.1, + time_step_floor=1e-4, + mamba_conv_bias=True, + mamba_proj_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.tie_word_embeddings = tie_word_embeddings + self.hidden_size = hidden_size + if attention_hidden_size is None: + self.attention_hidden_size = 2 * hidden_size + else: + self.attention_hidden_size = attention_hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + if attention_head_dim is None: + self.attention_head_dim = 2 * self.hidden_size // self.num_attention_heads + else: + self.attention_head_dim = attention_head_dim + self.max_position_embeddings = max_position_embeddings + self.attention_dropout = attention_dropout + + self.num_key_value_heads = num_key_value_heads + self.n_mamba_heads = n_mamba_heads + self.hidden_act = hidden_act + self.hidden_mamba_act = hidden_mamba_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep + + self.attn_layer_period = attn_layer_period + self.attn_layer_offset = attn_layer_offset + + self.use_mamba_kernels = use_mamba_kernels + self.mamba_d_state = mamba_d_state + self.mamba_d_conv = mamba_d_conv + self.mamba_expand = mamba_expand + self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank + self.time_step_min = time_step_min + self.time_step_max = time_step_max + self.time_step_floor = time_step_floor + self.mamba_conv_bias = mamba_conv_bias + self.mamba_proj_bias = mamba_proj_bias + + self.layers_block_type = self._layers_block_type(num_hidden_layers, attn_layer_period, attn_layer_offset) + + assert ( + self.mamba_expand * self.hidden_size + ) % self.n_mamba_heads == 0, "`intermediate_size` should be divisible by `n_mamba_heads`." + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _layers_block_type(self, num_hidden_layers, attn_layer_period, attn_layer_offset): + layers = [ + "mamba", + "mamba", + "hybrid", + ] + ["hybrid" if i % attn_layer_period == attn_layer_offset else "mamba" for i in range(num_hidden_layers - 3)] + return layers diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py new file mode 100644 index 0000000000..8c6c49a3a1 --- /dev/null +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -0,0 +1,1739 @@ +# coding=utf-8 +# Copyright 2024 Zyphra Technologies and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Zamba model.""" + +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, +) +from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from ...utils.import_utils import ( + is_causal_conv1d_available, + is_mamba_ssm_available, + is_torchdynamo_compiling, +) +from .configuration_zamba import ZambaConfig + + +if is_mamba_ssm_available(): + from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn + from mamba_ssm.ops.triton.selective_state_update import selective_state_update +else: + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) +) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "ZambaConfig" + + +# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Zamba +class ZambaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + ZambaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(ZambaRMSNorm) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class HybridMambaAttentionDynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__(self, config, batch_size, dtype=torch.float16, device=None): + self.dtype = dtype + self.layers_block_type = config.layers_block_type + self.has_previous_state = False # only used by mamba + intermediate_size = config.mamba_expand * config.hidden_size + ssm_state_size = config.mamba_d_state + conv_kernel_size = config.mamba_d_conv + self.n_mamba_heads = config.n_mamba_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + self._modules = {} + self._parameters = {} + self._buffers = {} + for i in range(config.num_hidden_layers): + self.conv_states += [ + torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) + ] + cache_shape = (batch_size, self.n_mamba_heads, intermediate_size // self.n_mamba_heads, ssm_state_size) + self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)] + if self.layers_block_type[i] == "hybrid": + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.update + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.reorder_cache + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.get_seq_length + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.to_legacy_cache + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.from_legacy_cache + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + +class ZambaAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + + Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: + The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads. + The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer + (see fig. 2 in https://arxiv.org/pdf/2405.16712). + Additionally, replaced + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2) + """ + + def __init__(self, config: ZambaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + self.attention_hidden_size = config.attention_hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.attention_head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.attention_hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.attention_hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.attention_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.attention_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim / 2) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size) + + attn_output = attn_output + attn_output = self.o_proj(attn_output) + attn_output = attn_output + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: +# Added softmax_scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of self._flash_attention_forward +# dropped use_sliding_windows from the arguments of self._flash_attention_forward +class ZambaFlashAttention2(ZambaAttention): + """ + Zamba flash attention module. This module inherits from `ZambaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + softmax_scale = 1 / math.sqrt(self.head_dim / 2) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=softmax_scale, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: +# added scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of torch.nn.functional.scaled_dot_product_attention +class ZambaSdpaAttention(ZambaAttention): + """ + Zamba attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `ZambaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + def forward( + self, + hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "ZambaModel is using ZambaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + softmax_scale = 1 / math.sqrt(self.head_dim / 2) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + scale=softmax_scale, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.attention_hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +ZAMBA_ATTENTION_CLASSES = { + "eager": ZambaAttention, + "flash_attention_2": ZambaFlashAttention2, + "sdpa": ZambaSdpaAttention, +} + + +class ZambaMambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + + This module differs from `transformers.models.mamba.modeling_mamba.MambaMixer` in two ways: + - Added multi-head: the output of `self.in_proj` is split into `self.n_mamba_heads` heads, and each head + undergoes an independent forward pass, identical to the original `MambaMixer`, up until the pre-activations of + `self.out_proj`. The pre-activations, coming from different mamba heads, are then concatenated and fed into `self.out_proj`. + """ + + def __init__(self, config: ZambaConfig, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = config.mamba_expand * config.hidden_size + self.time_step_rank = config.mamba_dt_rank + self.n_mamba_heads = config.n_mamba_heads + self.mamba_head_dim = self.intermediate_size // self.n_mamba_heads + self.use_conv_bias = config.mamba_conv_bias + self.use_bias = config.mamba_proj_bias + self.conv1d = nn.Conv1d( + in_channels=self.intermediate_size, + out_channels=self.intermediate_size, + bias=self.use_conv_bias, + kernel_size=self.conv_kernel_size, + groups=self.intermediate_size, + padding=self.conv_kernel_size - 1, + ) + + self.activation = config.hidden_mamba_act + self.act = ACT2FN[config.hidden_mamba_act] + + self.use_fast_kernels = config.use_mamba_kernels + + # projection of the input hidden states + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias) + # weight associated to the selective projection used to make dt, B and C input dependent + # each mamba head is processed independently + self.x_proj_weight = nn.Parameter( + ( + torch.zeros( + self.n_mamba_heads, + self.time_step_rank + self.ssm_state_size * 2, + self.mamba_head_dim, + ) + ) + ) + # time step projection (discretization) + self.dt_proj_weight = nn.Parameter( + (torch.zeros(self.n_mamba_heads, self.mamba_head_dim, self.time_step_rank) - 0.5) + * 2 + / self.time_step_rank**0.5 + ) + self.dt_proj_bias = nn.Parameter(torch.zeros(self.n_mamba_heads, self.mamba_head_dim)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :] + A = A.expand(self.intermediate_size, -1).contiguous() + self.A_log = nn.Parameter(torch.log(A).reshape(self.n_mamba_heads, self.mamba_head_dim, -1)) + self.D = nn.Parameter(torch.ones(self.n_mamba_heads, self.mamba_head_dim)) + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config" + ) + + def cuda_kernels_forward( + self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None + ): + batch_size, seq_len, _ = hidden_states.shape + use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1 + + # 1. Gated linear projection + projected_states = self.in_proj(hidden_states).transpose(1, 2) + + hidden_states, gate = projected_states.view(batch_size, -1, 2, seq_len).chunk(2, dim=2) + hidden_states = hidden_states.squeeze(2).contiguous() + gate = gate.squeeze(2) + gate = gate.reshape(batch_size, self.n_mamba_heads, -1, seq_len).transpose(0, 1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if use_precomputed_states: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), + cache_params.conv_states[self.layer_idx], + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.unsqueeze(-1) + else: + if attention_mask is not None and not torch.all(attention_mask == 1): + hidden_states = hidden_states * attention_mask.unsqueeze(1) + if cache_params is not None: + conv_states = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) + cache_params.conv_states[self.layer_idx].copy_(conv_states) + hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation) + if attention_mask is not None and not torch.all(attention_mask == 1): + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 3. SSM sequence transformation + # 3.a. input varying initialization of time_step, B and C + + hidden_states = hidden_states.reshape(-1, self.n_mamba_heads, self.mamba_head_dim, seq_len).transpose(0, 1) + ssm_parameters = (self.x_proj_weight[:, None, :, :] @ hidden_states).transpose(-1, -2) + + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + + discrete_time_step = self.dt_proj_weight[:, None] @ time_step.transpose(-1, -2) + + A = -torch.exp(self.A_log.float()) + + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = self.dt_proj_bias.float() if self.dt_proj_bias is not None else None + scan_outputs = torch.empty((batch_size, 0, seq_len), device=hidden_states.device, dtype=hidden_states.dtype) + + if use_precomputed_states: + for n in range(self.n_mamba_heads): + scan_outputs_ = selective_state_update( + cache_params.ssm_states[self.layer_idx][:, n], + hidden_states[n, ..., 0], + discrete_time_step[n, ..., 0], + A[n], + B[n, :, 0], + C[n, :, 0], + self.D[n], + gate[n, ..., 0], + time_proj_bias[n], + dt_softplus=True, + ).unsqueeze(-1) + scan_outputs = torch.cat((scan_outputs, scan_outputs_), dim=1) + + else: + ssm_state = torch.empty( + (batch_size, 0, self.mamba_head_dim, self.ssm_state_size), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + for n in range(self.n_mamba_heads): + scan_outputs_, ssm_state_ = selective_scan_fn( + hidden_states[n], + discrete_time_step[n], + A[n], + B[n].transpose(1, 2), + C[n].transpose(1, 2), + self.D[n].float(), + gate[n], + time_proj_bias[n], + delta_softplus=True, + return_last_state=True, + ) + scan_outputs = torch.cat((scan_outputs, scan_outputs_), dim=1).contiguous() + ssm_state = torch.cat((ssm_state, ssm_state_.unsqueeze(1)), dim=1) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) + return contextualized_states + + def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + # 1. Gated linear projection + projected_states = self.in_proj(input_states).transpose(1, 2) + + hidden_states, gate = projected_states.view(batch_size, -1, 2, seq_len).chunk(2, dim=2) + hidden_states = hidden_states.squeeze(2).contiguous() + gate = gate.squeeze(2) + gate = gate.reshape(batch_size, self.n_mamba_heads, -1, seq_len).transpose(0, 1) + + use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache) + # 2. Convolution sequence transformation + if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size: + if self.training: + # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + else: + ssm_state = cache_params.ssm_states[self.layer_idx] + + ssm_state = ssm_state.to(hidden_states.device) + + if ( + cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] == batch_size + ): + conv_state = cache_params.conv_states[self.layer_idx] + conv_state = torch.roll(conv_state, shifts=-1, dims=-1) + conv_state[:, :, -1] = hidden_states[:, :, 0] + cache_params.conv_states[self.layer_idx] = conv_state + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) + else: + if attention_mask is not None and not torch.all(attention_mask == 1): + hidden_states = hidden_states * attention_mask[:, -hidden_states.shape[-1] :].unsqueeze(1) + conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) + cache_params.conv_states[self.layer_idx] = conv_state + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) + if attention_mask is not None and not torch.all(attention_mask == 1): + hidden_states = hidden_states * attention_mask[:, -hidden_states.shape[-1] :].unsqueeze(1) + else: + ssm_state = torch.zeros( + (batch_size, self.n_mamba_heads, self.mamba_head_dim, self.ssm_state_size), + device=hidden_states.device, + dtype=dtype, + ) + if attention_mask is not None and not torch.all(attention_mask == 1): + hidden_states = hidden_states * attention_mask.unsqueeze(1) + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) + if attention_mask is not None and not torch.all(attention_mask == 1): + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 3. State Space Model sequence transformation + # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2] + hidden_states = hidden_states.reshape(-1, self.n_mamba_heads, self.mamba_head_dim, seq_len).transpose(0, 1) + ssm_parameters = (self.x_proj_weight[:, None, :, :] @ hidden_states).transpose(-1, -2) + + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + discrete_time_step = (self.dt_proj_weight[:, None] @ time_step.transpose(-1, -2)) + self.dt_proj_bias[ + :, None, :, None + ] + + discrete_time_step = nn.functional.softplus(discrete_time_step) + + # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM) + A = -torch.exp(self.A_log.float()) + discrete_A = torch.exp(A[:, None, :, None, :] * discrete_time_step[:, :, :, :, None]) + discrete_B = discrete_time_step[:, :, :, :, None] * B[:, :, None, :, :].float() + deltaB_u = discrete_B * hidden_states[:, :, :, :, None].float() + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + scan_outputs = [] + for i in range(seq_len): + ssm_state = discrete_A[:, :, :, i, :].transpose(0, 1) * ssm_state + deltaB_u[:, :, :, i, :].transpose(0, 1) + scan_output = torch.matmul(ssm_state.transpose(0, 1).to(dtype), C[:, :, i, :].unsqueeze(-1)) + scan_outputs.append(scan_output[:, :, :, 0]) + scan_output = torch.stack(scan_outputs, dim=-1) + scan_output = scan_output + (hidden_states * self.D[:, None, :, None]) + scan_output = scan_output * self.act(gate) + + if use_cache: + cache_params.ssm_states[self.layer_idx] = ssm_state + + # 4. Final linear projection + contextualized_states = self.out_proj( + scan_output.transpose(0, 1).reshape(batch_size, -1, seq_len).transpose(1, 2) + ) + return contextualized_states + + def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None): + if self.use_fast_kernels: + if not is_fast_path_available or "cuda" not in self.x_proj_weight.device.type: + raise ValueError( + "Fast Mamba kernels are not available. Make sure to they are installed and that " + "the mamba module is on a CUDA device. lease run 'pip install causal-conv1d>=1.2.0' " + "and 'pip install mamba-ssm', or set use_mamba_kernels=False in the model's config." + ) + return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask=attention_mask) + return self.slow_forward(hidden_states, cache_params, attention_mask=attention_mask) + + +# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Zamba +class ZambaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class ZambaAttentionDecoderLayer(nn.Module): + def __init__(self, config: ZambaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.self_attn = ZAMBA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.feed_forward = ZambaMLP(config) + self.input_layernorm = ZambaRMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = ZambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + original_hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)` + original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`. + This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The + concatenated tensor is then used as input of the pre-attention RMSNorm + (see fig. 2 in https://arxiv.org/pdf/2405.16712). + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + """ + hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1) + hidden_states = self.input_layernorm(hidden_states) + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + layer_idx=layer_idx, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + # feed-forward (MLP) + hidden_states = self.pre_ff_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class ZambaMambaDecoderLayer(nn.Module): + def __init__(self, config: ZambaConfig, layer_idx: int): + super().__init__() + self.mamba = ZambaMambaMixer(config=config, layer_idx=layer_idx) + self.input_layernorm = ZambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layer_idx = layer_idx + + def forward( + self, + hidden_states: torch.Tensor, + original_hidden_states: Optional[torch.Tensor] = None, + layer_idx: int = None, + attention_mask: Optional[torch.Tensor] = None, + causal_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + transformer_hidden_states: Optional[torch.Tensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + """ + + residual = hidden_states + + # `transformer_hidden_states` is the output from shared transformer + linear layer (see fig. 2 in https://arxiv.org/pdf/2405.16712). + # `transformer_hidden_states` is then added to the input to the mamba layer below (as described in eq. (6) of https://arxiv.org/pdf/2405.16712). + hidden_states = ( + hidden_states + transformer_hidden_states if transformer_hidden_states is not None else hidden_states + ) + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.mamba( + hidden_states=hidden_states, + cache_params=past_key_value, + attention_mask=attention_mask, + ) + + self_attn_weights = None + + # residual connection after mamba + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + return outputs + + +class HybridLayer(nn.Module): + def __init__(self, shared_transf: ZambaAttentionDecoderLayer, linear: nn.Linear, mamba: ZambaMambaDecoderLayer): + super().__init__() + self.shared_transf = shared_transf + self.linear = linear + self.mamba_decoder = mamba + + def forward( + self, + hidden_states: torch.Tensor, + original_hidden_states: Optional[torch.Tensor] = None, + layer_idx: int = None, + attention_mask: Optional[torch.Tensor] = None, + causal_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + original_hidden_states (`torch.FloatTensor`): word embedding output that will be concatenated with + hidden activations to form the input of the shared transformer layer. + layer_idx (`int`): layer number. + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + """ + + layer_outputs = self.shared_transf( + hidden_states, + original_hidden_states=original_hidden_states, + layer_idx=layer_idx, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + transformer_hidden_states = layer_outputs[0] + + if output_attentions: + self_attn_weights = layer_outputs[1] + + transformer_hidden_states = self.linear(transformer_hidden_states) + + layer_outputs = self.mamba_decoder( + hidden_states, + transformer_hidden_states=transformer_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + if output_attentions: + layer_outputs = (layer_outputs[0], self_attn_weights) + layer_outputs[2:] + + return layer_outputs + + +ZAMBA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`ZambaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Zamba Model outputting raw hidden-states without any specific head on top.", + ZAMBA_START_DOCSTRING, +) +class ZambaPreTrainedModel(PreTrainedModel): + config_class = ZambaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["ZambaAttentionDecoderLayer", "ZambaMambaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = False + _supports_sdpa = False + _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache + _is_stateful = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, ZambaMambaMixer): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + module.x_proj_weight.data.normal_(mean=0.0, std=std) + dt_init_std = self.config.mamba_dt_rank**-0.5 + nn.init.uniform_(module.dt_proj_weight, -dt_init_std, dt_init_std) + + mamba_head_dim = self.config.mamba_expand * self.config.hidden_size // self.config.n_mamba_heads + dt = torch.exp( + torch.rand(self.config.n_mamba_heads, mamba_head_dim) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) + ).clamp(min=self.config.time_step_floor) + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + + with torch.no_grad(): + module.dt_proj_bias.copy_(inv_dt) + module.dt_proj_bias._no_reinit = True + + @classmethod + @classmethod + def _check_and_enable_flash_attn_2( + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + hard_check_only: bool = False, + check_device_map: bool = False, + ): + """ + Overloads `PreTrainedModel._check_and_enable_flash_attn_2` so as to DISABLE Flash Attention 2 by default on Zamba models. + Flash attention 2 is currently not supported in the HuggingFace implementation of Zamba v1. + """ + config = super()._check_and_enable_flash_attn_2( + config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map + ) + + # if using the default path -> swap sdpa by eager + if not hard_check_only and config._attn_implementation == "flash_attention_2": + config._attn_implementation = "eager" + + return config + + +ZAMBA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the + self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. + Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and + `(batch_size, d_inner, d_state)` respectively. + See the `HybridMambaAttentionDynamicCache` class for more details. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Zamba Model outputting raw hidden-states without any specific head on top.", + ZAMBA_START_DOCSTRING, +) +class ZambaModel(ZambaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ZambaDecoderLayer`] + + Args: + config: ZambaConfig + """ + + def __init__(self, config: ZambaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + block = ZambaAttentionDecoderLayer(config) + mamba_layers = [] + linear_layers = [] + self.layers_block_type = config.layers_block_type + for i in range(config.num_hidden_layers): + if config.layers_block_type[i] == "mamba": + mamba_layers.append(ZambaMambaDecoderLayer(config, layer_idx=i)) + elif config.layers_block_type[i] == "hybrid": + linear_layers.append(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)) + mamba_layers.append(ZambaMambaDecoderLayer(config, layer_idx=i)) + mamba_layers = iter(mamba_layers) + linear_layers = iter(linear_layers) + layers = [] + self._tied_weights_keys = [] + for layer_id, layer_type in enumerate(self.layers_block_type): + if layer_type == "hybrid": + prefix_name = f"layers.{layer_id}." + tied_keys = [ + "shared_transf.self_attn.q_proj.weight", + "shared_transf.self_attn.k_proj.weight", + "shared_transf.self_attn.v_proj.weight", + "shared_transf.self_attn.o_proj.weight", + "shared_transf.feed_forward.gate_proj.weight", + "shared_transf.feed_forward.up_proj.weight", + "shared_transf.feed_forward.down_proj.weight", + "shared_transf.input_layernorm.weight", + "shared_transf.pre_ff_layernorm.weight", + ] + self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] + layers.append(HybridLayer(block, next(linear_layers), next(mamba_layers))) + else: + layers.append(next(mamba_layers)) + self.layers = nn.ModuleList(layers) + + self._attn_implementation = config._attn_implementation + self.final_layernorm = ZambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(ZAMBA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + original_hidden_states = torch.clone(inputs_embeds) + # original_hidden_states: word embedding output that will be concatenated with hidden activations to form the input of the shared transformer layer + + if use_cache and past_key_values is None: + logger.warning_once( + "Zamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " + "provided, so no cache will be returned." + ) + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for layer_idx, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + original_hidden_states, + layer_idx, + attention_mask, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = layer( + hidden_states, + original_hidden_states=original_hidden_states, + layer_idx=layer_idx, + attention_mask=attention_mask, + causal_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + if layer_outputs[1] is not None: + # append attentions only of attention layers. Mamba layers return `None` as the attention weights + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + next_cache = None if not use_cache else past_key_values + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.jamba.modeling_jamba.JambaModel._update_causal_mask + def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + target_length = cache_position[-1] + 1 + + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +# Adapted from transformers.models.jamba.modeling_jamba.JambaForCausalLM with Jamba->Zamba, JAMBA->ZAMBA +class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin): + def __init__(self, config: ZambaConfig): + super().__init__(config) + self.model = ZambaModel(config) + self._tied_weights_keys = ["lm_head.weight", *self.model._tied_weights_keys] + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(ZAMBA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, ZambaForCausalLM + + >>> model = ZambaForCausalLM.from_pretrained("Zyphra/Zamba-7B-v1") + >>> tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba-7B-v1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + empty_past_kv = past_key_values is None + + # Omit tokens covered by past_key_values + if not empty_past_kv: + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], dtype=self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "num_logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + + +@add_start_docstrings( + """ + The Zamba Model with a sequence classification head on top (linear layer). + + [`ZambaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + ZAMBA_START_DOCSTRING, +) +class ZambaForSequenceClassification(ZambaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = ZambaModel(config) + self._tied_weights_keys = self.model._tied_weights_keys + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(ZAMBA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 2e2b0726f6..ea0bbc1701 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -9975,6 +9975,34 @@ class YosoPreTrainedModel(metaclass=DummyObject): requires_backends(self, ["torch"]) +class ZambaForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ZambaForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ZambaModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ZambaPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class ZoeDepthForDepthEstimation(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 54ea0e23b3..29b31dab50 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2100,7 +2100,17 @@ class GenerationTesterMixin: # 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1" # 2. We ignore models that have unique cache structures (e.g. mamba) or are in need of refatoring to match the # standard cache format (e.g.gptbigcode ) - models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba", "xlnet") + models_without_standard_cache = ( + "ctrl", + "fsmt", + "gptbigcode", + "mega", + "reformer", + "jamba", + "mamba", + "xlnet", + "zamba", + ) has_standard_cache = not any( model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache ) diff --git a/tests/models/zamba/__init__.py b/tests/models/zamba/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py new file mode 100644 index 0000000000..c0a8020bed --- /dev/null +++ b/tests/models/zamba/test_modeling_zamba.py @@ -0,0 +1,736 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Zamba model.""" + +import math +import tempfile +import unittest + +import pytest +from parameterized import parameterized + +from transformers import AutoTokenizer, ZambaConfig, is_torch_available +from transformers.testing_utils import ( + require_bitsandbytes, + require_flash_attn, + require_torch, + require_torch_gpu, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor, random_attention_mask +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + ZambaForCausalLM, + ZambaForSequenceClassification, + ZambaModel, + ) + from transformers.models.zamba.modeling_zamba import ( + HybridMambaAttentionDynamicCache, + ) + + +class ZambaModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_labels=True, + vocab_size=99, + hidden_size=64, + mamba_dt_rank=32, + num_hidden_layers=5, + attn_layer_offset=1, + attn_layer_period=8, + num_attention_heads=4, + num_key_value_heads=4, + n_mamba_heads=2, + intermediate_size=37, + hidden_act="gelu", + hidden_mamba_act="silu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.mamba_dt_rank = mamba_dt_rank + self.num_hidden_layers = num_hidden_layers + self.attn_layer_offset = attn_layer_offset + self.attn_layer_period = attn_layer_period + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.n_mamba_heads = n_mamba_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_mamba_act = hidden_mamba_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return ZambaConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + mamba_dt_rank=self.mamba_dt_rank, + num_hidden_layers=self.num_hidden_layers, + attn_layer_offset=self.attn_layer_offset, + attn_layer_period=self.attn_layer_period, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + n_mamba_heads=self.n_mamba_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_mamba_act=self.hidden_mamba_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=True, + initializer_range=self.initializer_range, + use_mamba_kernels=False, + ) + + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + config.is_decoder = True + + return ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) + + def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): + model = ZambaModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_causal_lm( + self, + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = ZambaForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids, labels=token_labels) + result = model(input_ids) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + config.is_decoder = True + config.add_cross_attention = True + model = ZambaForCausalLM(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + # Attention: Zamba needs the cache to be initialized to return a cache! + past_key_values = HybridMambaAttentionDynamicCache( + config, input_ids.shape[0], model.dtype, device=model.device + ) + outputs = model( + input_ids, + attention_mask=input_mask, + past_key_values=past_key_values, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + output_hidden_states=True, + )["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + cache_position=torch.arange( + input_ids.shape[1], input_ids.shape[1] + next_tokens.shape[1], device=model.device + ), + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_for_sequence_classification( + self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = ZambaForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=sequence_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class ZambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + ZambaModel, + ZambaForCausalLM, + ZambaForSequenceClassification, + ) + if is_torch_available() + else () + ) + all_generative_model_classes = (ZambaForCausalLM,) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": ZambaModel, + "text-classification": ZambaForSequenceClassification, + "text-generation": ZambaForCausalLM, + "zero-shot": ZambaForSequenceClassification, + } + if is_torch_available() + else {} + ) + test_headmasking = False + test_pruning = False + + def setUp(self): + self.model_tester = ZambaModelTester(self) + self.config_tester = ConfigTester(self, config_class=ZambaConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_casual_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) + + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_initialization(self): + r""" + Overriding the test_initialization test as the A_log and D params of the Mamba block are initialized differently + """ + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + if param.requires_grad: + if "A_log" in name: + A = torch.arange(1, config.mamba_d_state + 1, dtype=torch.float32)[None, :] + self.assertTrue(torch.allclose(param.data, torch.log(A), atol=1e-5, rtol=1e-5)) + elif "D" in name: + # check if it's a ones like + self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5)) + elif "x_proj" in name or "dt_proj_weight" in name: + self.assertIn( + ((param.data.mean() * 1e2).round() / 1e2).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized (raw value {param.data.mean()})", + ) + elif "dt_proj_bias" in name: + dt = torch.exp( + torch.tensor([0, 1]) * (math.log(config.time_step_max) - math.log(config.time_step_min)) + + math.log(config.time_step_min) + ).clamp(min=config.time_step_floor) + inv_dt = dt + torch.log(-torch.expm1(-dt)) + if param.requires_grad: + self.assertTrue(param.data.max().item() <= inv_dt[1]) + self.assertTrue(param.data.min().item() >= inv_dt[0]) + else: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + def test_mismatched_shapes_have_properly_initialized_weights(self): + r""" + Overriding the test_mismatched_shapes_have_properly_initialized_weights test because A_log and D params of the + Mamba block are initialized differently and we tested that in test_initialization + """ + self.skipTest("Cumbersome and redundant for Zamba") + + def test_attention_outputs(self): + r""" + Overriding the test_attention_outputs test as the Zamba model outputs attention only for its attention layers + """ + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + seq_len = getattr(self.model_tester, "seq_length", None) + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + + expected_num_attentions = ( + math.ceil( + (self.model_tester.num_hidden_layers - self.model_tester.attn_layer_offset) + / self.model_tester.attn_layer_period + ) + + 1 + ) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), expected_num_attentions) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), expected_num_attentions) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.attentions + + self.assertEqual(len(self_attentions), expected_num_attentions) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + + def _get_input_ids_and_config(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + return config, input_ids, input_mask + + def test_left_padding_compatibility(self): + r""" + Overriding the test_left_padding_compatibility test as the mamba layers accentuate the numerical differences + effect of the left padding discussed in the issue in the note. Using a more permissive tolerance value. + """ + import inspect + # NOTE: left-padding results in small numerical differences. This is expected. + # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 + + # First, filter out models that don't support left padding - generative and decoder-only. + # Zamba is a decoder-only architecture + decoder_only_classes = self.all_generative_model_classes + + # Then, test left-padding + def _prepare_model_kwargs(input_ids, attention_mask, signature): + model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask} + if "position_ids" in signature: + position_ids = torch.cumsum(attention_mask, dim=-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + model_kwargs["position_ids"] = position_ids + if "cache_position" in signature: + cache_position = torch.arange(input_ids.shape[-1], device=torch_device) + model_kwargs["cache_position"] = cache_position + return model_kwargs + + for model_class in decoder_only_classes: + config, input_ids, attention_mask = self._get_input_ids_and_config() + model = model_class(config).to(torch_device).eval() + signature = inspect.signature(model.forward).parameters.keys() + + # Without padding + model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature) + next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :] + + # With left-padding (length 32) + pad_size = (input_ids.shape[0], 32) + padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id + padded_input_ids = torch.cat((padding, input_ids), dim=1) + padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1) + model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature) + next_logits_with_padding = model(**model_kwargs).logits[:, -1, :] + + # They should result in very similar logits + self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3)) + + @require_flash_attn + @require_torch_gpu + @require_bitsandbytes + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_fp32_ln(self): + r""" + Overriding the test_flash_attn_2_fp32_ln test as the Zamba model, like Mixtral, doesn't support + right padding + use cache with FA2 + """ + for model_class in self.all_generative_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_input = inputs_dict[model.main_input_name] + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # NOTE: Zamba does not support right padding + use_cache with FA2. + dummy_attention_mask[:, -1] = 1 + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + load_in_4bit=True, + ) + + for _, param in model.named_parameters(): + # upcast only layer norms + if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16): + param.data = param.data.to(torch.float32) + + _ = model(dummy_input) + # with attention mask + _ = model(dummy_input, attention_mask=dummy_attention_mask) + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_generate_padding_right(self): + r""" + Overriding the test_flash_attn_2_generate_padding_right test as the Zamba model, like Mixtral, doesn't support + right padding + use cache with FA2 + """ + import torch + + for model_class in self.all_generative_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) + + model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + with self.assertRaises(ValueError): + _ = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_generate_use_cache(self): + r""" + Overriding the test_flash_attn_2_generate_use_cache test as the Zamba model, like Mixtral, doesn't support + right padding + use cache with FA2 + """ + import torch + + max_new_tokens = 30 + + for model_class in self.all_generative_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # NOTE: Zamba does not support right padding + use_cache with FA2. + dummy_attention_mask[:, -1] = 1 + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence_right_padding(self): + r""" + Overriding the test_flash_attn_2_inference_padding_right test as the Zamba model, like Mixtral, doesn't support + right padding + use cache with FA2 + """ + self.skipTest(reason="Zamba flash attention does not support right padding") + + @unittest.skip(reason="Zamba has its own special cache type") + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + pass + + +@require_torch +class ZambaModelIntegrationTest(unittest.TestCase): + model = None + tokenizer = None + + @classmethod + @slow + def setUpClass(cls): + model_id = "Zyphra/Zamba-7B-v1" + cls.model = ZambaForCausalLM.from_pretrained( + model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, use_mamba_kernels=False + ) + cls.tokenizer = AutoTokenizer.from_pretrained(model_id) + + @slow + def test_simple_generate(self): + self.model.to(torch_device) + + input_ids = self.tokenizer("Hey how are you doing on this lovely evening?", return_tensors="pt")[ + "input_ids" + ].to(torch_device) + out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10) + output_sentence = self.tokenizer.decode(out[0, :]) + self.assertEqual( + output_sentence, + " Hey how are you doing on this lovely evening? I hope you are all doing well. I am", + ) + + with torch.no_grad(): + logits = self.model(input_ids=input_ids).logits + + EXPECTED_LOGITS_NO_GRAD = torch.tensor( + [ + -7.9375, 8.1875, 1.3984, -6.0000, -7.9375, -7.9375, -7.9375, -7.9375, + -7.9375, -7.9375, -7.9375, -7.9375, 2.7500, 13.0625, -7.9375, -7.9375, + -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, + -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, + -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, -7.9375 + ] + , dtype=torch.float32) # fmt: skip + + torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1e-3) + + @slow + def test_simple_batched_generate_with_padding(self): + self.model.to(torch_device) + self.tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + self.model.resize_token_embeddings(len(self.tokenizer)) + + inputs = self.tokenizer( + ["Hey how are you doing on this lovely evening?", "Tell me a story"], padding=True, return_tensors="pt" + ).to(torch_device) + out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10) + output_sentences = self.tokenizer.batch_decode(out) + self.assertEqual( + output_sentences[0], + " Hey how are you doing on this lovely evening? I hope you are all doing well. I am", + ) + self.assertEqual( + output_sentences[1], + "[PAD][PAD][PAD][PAD][PAD][PAD] Tell me a story about a time when you were in a difficult situation", + ) + + with torch.no_grad(): + logits = self.model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).logits + + EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor( + [ + -7.9375, 8.1250, 1.3594, -6.0000, -7.9375, -7.9375, -7.9375, -7.9375, + -7.9375, -7.9375, -7.9375, -7.9375, 2.7344, 13.0625, -7.9375, -7.9375, + -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, + -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, + -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, -7.9375, -7.9375 + ] + , dtype=torch.float32) # fmt: skip + + EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor( + [ + -6.3750, 3.4219, 0.6719, -5.0312, -8.5000, -8.5000, -8.5000, -8.5000, + -8.5000, -8.5000, -8.5000, -8.5000, 2.0625, 10.3750, -8.5000, -8.5000, + -8.5000, -8.5000, -8.5000, -8.5000, -8.5000, -8.5000, -8.5000, -8.5000, + -8.5000, -8.5000, -8.5000, -8.5000, -8.5000, -8.5000, -8.5000, -8.5000, + -8.5000, -8.5000, -8.5000, -8.5000, -8.5000, -8.5000, -8.5000, -8.5000 + ] + , dtype=torch.float32) # fmt: skip + + torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1e-3) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 7cd8523ccd..7bd9379636 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -132,6 +132,11 @@ SPECIAL_CASES_TO_ALLOW = { "t2u_variance_predictor_hidden_dim", "t2u_variance_predictor_kernel_size", ], + "ZambaConfig": [ + "tie_word_embeddings", + "attn_layer_offset", + "attn_layer_period", + ], "MllamaTextConfig": [ "initializer_range", ], diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt index cd87d09ec8..9eb43e7b90 100644 --- a/utils/not_doctested.txt +++ b/utils/not_doctested.txt @@ -907,6 +907,8 @@ src/transformers/models/xmod/convert_xmod_original_pytorch_checkpoint_to_pytorch src/transformers/models/yolos/convert_yolos_to_pytorch.py src/transformers/models/yoso/convert_yoso_pytorch_to_pytorch.py src/transformers/models/yoso/modeling_yoso.py +src/transformers/models/zamba/configuration_zamba.py +src/transformers/models/zamba/modeling_zamba.py src/transformers/onnx/__main__.py src/transformers/onnx/config.py src/transformers/onnx/convert.py