From 1aa7256f01ce771220daaaf36af33b9f59447e5c Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com> Date: Mon, 21 Jul 2025 14:59:36 +0200 Subject: [PATCH] Refactor `MambaCache` to `modeling_mamba.py` (#38086) * Refactor MambaCache to modeling_mamba.py (parity with Zamba) * ruff * fix dummies * update * update * remove mamba ref in cache tests * remove cache_implementation from tests * update * ruff * ruff * sneaky regression * model consistency * fix test_multi_gpu_data_parallel_forward * fix falcon slow tests * ruff * ruff * add sample false * try to fix slow tests * Revert "fix test_multi_gpu_data_parallel_forward" This reverts commit 66b7162c7c5c5ce8a73ccf48cffc8a96343ebb33. * fix tests on nvidia t4, remove dataparallel tests from mamba * ruff * remove DDP tests from mamba and falcon_mamba * add explicit error for MambaCache * mamba2 also needs to init cache in prepare_inputs_for_generation * ruff * ruff * move MambaCache to its own file * ruff * unprotected import fix * another attempt to fix unprotected imports * Revert "another attempt to fix unprotected imports" This reverts commit 2338354fcab630de5899321f5daced5fb312c2a2. * fixing unprotected import, attempt 3 * Update src/transformers/cache_utils.py * ruff's fault * fix arthur review * modular falcon mamba * found a hack * fix config docs * fix docs * add export info * merge modular falcon branch * oopsie * fix fast path failing * new approach * oopsie * fix types * Revert new pragma in modular This reverts commit 80b1cf160ee251536f07c40b8a0857d499e70db6. 
* trying another modular workaround * review & fix ci * oopsie * clear prepare_inputs on mamba/mamba2/falcon_mamba --- docs/source/en/model_doc/falcon_mamba.md | 7 + docs/source/en/model_doc/mamba.md | 7 + src/transformers/__init__.py | 1 - src/transformers/cache_utils.py | 221 +++---- .../generation/configuration_utils.py | 3 - src/transformers/generation/utils.py | 9 +- .../configuration_falcon_mamba.py | 27 +- .../falcon_mamba/modeling_falcon_mamba.py | 232 +++++--- .../falcon_mamba/modular_falcon_mamba.py | 540 ++++++++++++++++++ .../models/mamba/modeling_mamba.py | 156 ++++- .../models/mamba2/modeling_mamba2.py | 41 +- src/transformers/utils/dummy_pt_objects.py | 7 - .../test_modeling_falcon_mamba.py | 43 +- tests/models/mamba/test_modeling_mamba.py | 37 +- tests/utils/test_cache_utils.py | 2 - utils/modular_model_converter.py | 7 +- 16 files changed, 1033 insertions(+), 307 deletions(-) create mode 100644 src/transformers/models/falcon_mamba/modular_falcon_mamba.py diff --git a/docs/source/en/model_doc/falcon_mamba.md b/docs/source/en/model_doc/falcon_mamba.md index a8d7886894..0b797c7c78 100644 --- a/docs/source/en/model_doc/falcon_mamba.md +++ b/docs/source/en/model_doc/falcon_mamba.md @@ -110,6 +110,13 @@ outputs = model.generate(**inputs, max_new_tokens=100) print(tokenizer.decode(outputs[0], skip_special_tokens=True)) ``` +## FalconMambaCache + +[[autodoc]] FalconMambaCache + - update_conv_state + - update_ssm_state + - reset + ## FalconMambaConfig [[autodoc]] FalconMambaConfig diff --git a/docs/source/en/model_doc/mamba.md b/docs/source/en/model_doc/mamba.md index 1e30e9af8b..06efa75971 100644 --- a/docs/source/en/model_doc/mamba.md +++ b/docs/source/en/model_doc/mamba.md @@ -116,6 +116,13 @@ print(tokenizer.decode(output[0], skip_special_tokens=True)) trainer.train() ``` +## MambaCache + +[[autodoc]] MambaCache + - update_conv_state + - update_ssm_state + - reset + ## MambaConfig [[autodoc]] MambaConfig diff --git a/src/transformers/__init__.py 
b/src/transformers/__init__.py index aeb448ab1e..3d1566580a 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -371,7 +371,6 @@ else: "EncoderDecoderCache", "HQQQuantizedCache", "HybridCache", - "MambaCache", "OffloadedCache", "OffloadedStaticCache", "QuantizedCache", diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 0b7052c634..eecf0c7c0e 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -2,6 +2,7 @@ import copy import importlib.metadata import json import os +import warnings from collections.abc import Iterable from dataclasses import dataclass from typing import Any, Optional, Union @@ -18,6 +19,7 @@ from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_great if is_hqq_available(): from hqq.core.quantize import Quantizer as HQQQuantizer + logger = logging.get_logger(__name__) @@ -2091,106 +2093,6 @@ class OffloadedHybridCache(HybridChunkedCache): self.device_value_cache[self.active_device_layer].fill_(0.0) -class MambaCache: - """ - Cache for mamba model which does not have attention mechanism and key value states. - - Arguments: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. - dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): - The default `dtype` to use when initializing the layer. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. Should be the same as the layer. 
- - Example: - - ```python - >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache - - >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") - - >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values - MambaCache() - ``` - """ - - is_compileable = True - - # TODO (joao): add layer_device_map arg and update code in `generate` accordingly - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - dtype: torch.dtype = torch.float16, - device: Union[torch.device, str, None] = None, - ): - self.max_batch_size = max_batch_size - self._dtype = dtype - self.intermediate_size = config.intermediate_size - self.ssm_state_size = config.state_size - self.conv_kernel_size = config.conv_kernel - - self.conv_states: list[torch.Tensor] = [] - self.ssm_states: list[torch.Tensor] = [] - device = torch.device(device) if device is not None else None - for _ in range(config.num_hidden_layers): - conv_state: torch.Tensor = torch.zeros( - self.max_batch_size, - self.intermediate_size, - self.conv_kernel_size, - device=device, - dtype=self._dtype, - ) - ssm_state: torch.Tensor = torch.zeros( - self.max_batch_size, - self.intermediate_size, - self.ssm_state_size, - device=device, - dtype=self._dtype, - ) - - torch._dynamo.mark_static_address(conv_state) - torch._dynamo.mark_static_address(ssm_state) - self.conv_states.append(conv_state) - self.ssm_states.append(ssm_state) - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor - ) -> torch.Tensor: - # This `if` blocks is only reached in multigpu and if 
`layer_device_map` is not passed. It is used - # when the cache is initialized in the forward pass (e.g. Mamba) - if self.conv_states[layer_idx].device != new_conv_state.device: - self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) - - conv_state = self.conv_states[layer_idx] - cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) - - conv_state = conv_state.roll(shifts=-1, dims=-1) - conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype) - self.conv_states[layer_idx].zero_() - self.conv_states[layer_idx] += conv_state - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device) - return self.ssm_states[layer_idx] - - def reset(self): - for layer_idx in range(len(self.conv_states)): - # In-place ops prevent breaking the static address - self.conv_states[layer_idx].zero_() - self.ssm_states[layer_idx].zero_() - - class OffloadedStaticCache(StaticCache): """ Static cache class to be used with `torch.compile(model)` that offloads to the CPU or @@ -2461,3 +2363,122 @@ class OffloadedStaticCache(StaticCache): self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True) self._device_value_cache[layer_idx & 1].copy_(self.value_cache[layer_idx], non_blocking=True) + + +# TODO (manuel, joao): remove this class, it is here only for backwards compatibility +# PEP 562: Lazy loading for deprecated location of MambaCache +def __getattr__(name: str) -> Any: + if name == "MambaCache": + warnings.warn( + ( + "Importing `MambaCache` from `transformers.cache_utils` is deprecated and will be removed " + "in a future version. Please import it from `transformers` or `transformers.models.mamba.cache_mamba` instead." 
+ ), + FutureWarning, + stacklevel=2, + ) + + class MambaCache: + """ + Importing `MambaCache` from `transformers.cache_utils` is deprecated and will be removed + in a future version. Please import it from `transformers` or `transformers.models.mamba.cache_mamba` instead. + + Cache for mamba model which does not have attention mechanism and key value states. + + Arguments: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. + dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default `dtype` to use when initializing the layer. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. Should be the same as the layer. + + Example: + + ```python + >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache + + >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") + + >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values + MambaCache() + ``` + """ + + is_compileable = True + + # TODO (joao): add layer_device_map arg and update code in `generate` accordingly + def __init__( + self, + config, + max_batch_size: int, + dtype: torch.dtype = torch.float16, + device: Union[torch.device, str, None] = None, + ): + self.max_batch_size = max_batch_size + self._dtype = dtype + self.intermediate_size = config.intermediate_size + self.ssm_state_size = config.state_size + 
self.conv_kernel_size = config.conv_kernel + + self.conv_states: list[torch.Tensor] = [] + self.ssm_states: list[torch.Tensor] = [] + device = torch.device(device) if device is not None else None + for _ in range(config.num_hidden_layers): + conv_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.conv_kernel_size, + device=device, + dtype=self._dtype, + ) + ssm_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.ssm_state_size, + device=device, + dtype=self._dtype, + ) + + torch._dynamo.mark_static_address(conv_state) + torch._dynamo.mark_static_address(ssm_state) + self.conv_states.append(conv_state) + self.ssm_states.append(ssm_state) + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + # This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used + # when the cache is initialized in the forward pass (e.g. 
Mamba) + if self.conv_states[layer_idx].device != new_conv_state.device: + self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) + + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device) + return self.ssm_states[layer_idx] + + def reset(self): + for layer_idx in range(len(self.conv_states)): + # In-place ops prevent breaking the static address + self.conv_states[layer_idx].zero_() + self.ssm_states[layer_idx].zero_() + + return MambaCache + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 7658765937..165252927c 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -54,7 +54,6 @@ if is_torch_available(): HQQQuantizedCache, HybridCache, HybridChunkedCache, - MambaCache, OffloadedHybridCache, OffloadedStaticCache, QuantizedCacheConfig, @@ -75,7 +74,6 @@ if is_torch_available(): "hybrid_chunked": HybridChunkedCache, "offloaded_hybrid": OffloadedHybridCache, "offloaded_hybrid_chunked": OffloadedHybridCache, - "mamba": MambaCache, } QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache} ALL_CACHE_IMPLEMENTATIONS = ( @@ -186,7 +184,6 @@ class GenerationConfig(PushToHubMixin): - `"offloaded_static"`: [`OffloadedStaticCache`] - `"sliding_window"`: [`SlidingWindowCache`] - `"hybrid"`: [`HybridCache`] - - `"mamba"`: 
[`MambaCache`] - `"quantized"`: [`QuantizedCache`] If none is specified, we will use the default cache for the model (which is often [`DynamicCache`]). See diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 809db679d3..76b3d7bd8a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1916,9 +1916,8 @@ class GenerationMixin(ContinuousMixin): or isinstance( cache_to_check, (HybridChunkedCache, OffloadedHybridCache) ) # due to internal slicing, we always re-init + or cache_to_check.max_cache_len < max_cache_len ) - if cache_implementation != "mamba": - need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len if requires_cross_attention_cache and hasattr(self, "_cache"): need_new_cache = ( @@ -1957,9 +1956,9 @@ class GenerationMixin(ContinuousMixin): def _supports_default_dynamic_cache(cls) -> bool: """ Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`. - This adds exception for some models like `Jamba` model which uses its own `HybridMambaAttentionDynamicCache` + This adds exception for some models like `Mamba` models which use their own caches and do not need to initialize the Cache in advance in order to save memory (because no back and forth - `to_legacy_cache` and `from_legacy_cache` will be performed for `HybridMambaAttentionDynamicCache`). + `to_legacy_cache` and `from_legacy_cache` will be performed for mamba-based models). 
""" # NOTE: remove xlnet/reformer when the models are deprecated, non-standard model architecture/cache name return not cls._is_stateful and all( @@ -2016,7 +2015,7 @@ class GenerationMixin(ContinuousMixin): if generation_config.use_cache is False: return - # Quick escape route 3: model that only supports legacy caches = nothing to prepare + # Quick escape route 3: model that only supports legacy caches or models that supply it in `prepare_inputs_for_generation` (mamba, zamba, ...) if not self._supports_default_dynamic_cache(): if generation_config.cache_implementation is not None: warnings.warn( diff --git a/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py b/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py index 4099920f40..86a0e9ad22 100644 --- a/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/configuration_falcon_mamba.py @@ -1,5 +1,11 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/falcon_mamba/modular_falcon_mamba.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_falcon_mamba.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. +# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,15 +18,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""FALCONMAMBA configuration""" import math from ...configuration_utils import PretrainedConfig -from ...utils import logging - - -logger = logging.get_logger(__name__) class FalconMambaConfig(PretrainedConfig): @@ -79,10 +80,13 @@ class FalconMambaConfig(PretrainedConfig): Whether or not to rescale `out_proj` weights when initializing. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the cache should be used. - use_mambapy (`bool`, *optional*, defaults to `False`): - Determines the fallback strategy during training if the CUDA-based official implementation of FalconMamba is not available. If `True`, the falcon_mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited. + use_falcon_mambapy (`bool`, *optional*, defaults to `False`): + This argument corresponds to `use_mambapy` in MambaConfig. + Determines the fallback strategy during training if the CUDA-based official implementation of Mamba is not available. If `True`, the mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited. mixer_rms_eps (`float`, *optional*, defaults to 1e-06): The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states. 
+ + Example: ```python @@ -125,10 +129,11 @@ class FalconMambaConfig(PretrainedConfig): time_step_floor=1e-4, rescale_prenorm_residual=False, use_cache=True, - use_mambapy=False, + use_falcon_mambapy=False, mixer_rms_eps=1e-6, **kwargs, ): + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs) self.vocab_size = vocab_size self.hidden_size = hidden_size self.state_size = state_size @@ -153,10 +158,8 @@ class FalconMambaConfig(PretrainedConfig): self.rescale_prenorm_residual = rescale_prenorm_residual self.residual_in_fp32 = residual_in_fp32 self.use_cache = use_cache - self.use_mambapy = use_mambapy + self.use_falcon_mambapy = use_falcon_mambapy self.mixer_rms_eps = mixer_rms_eps - super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs) - __all__ = ["FalconMambaConfig"] diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index c32157ab11..56a5770ba7 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/falcon_mamba/modular_falcon_mamba.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_falcon_mamba.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team. # @@ -12,29 +18,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""PyTorch FALCONMAMBA model.""" import math from dataclasses import dataclass from typing import Any, Optional, Union import torch -import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import MambaCache +from ...configuration_utils import PretrainedConfig from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging -from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available +from ...utils.import_utils import ( + is_causal_conv1d_available, + is_mamba_ssm_available, + is_mambapy_available, +) from .configuration_falcon_mamba import FalconMambaConfig -logger = logging.get_logger(__name__) - if is_mambapy_available(): from mambapy.pscan import pscan else: @@ -53,9 +59,109 @@ if is_causal_conv1d_available(): else: causal_conv1d_update, causal_conv1d_fn = None, None -is_fast_path_available = all( - (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) -) + +logger = logging.get_logger(__name__) + + +class FalconMambaCache: + """ + Cache for falcon_mamba model which does not have attention mechanism and key value states. + + Arguments: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. + dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default `dtype` to use when initializing the layer. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. Should be the same as the layer. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, FalconMambaForCausalLM, FalconMambaCache + + >>> model = FalconMambaForCausalLM.from_pretrained("state-spaces/falcon_mamba-130m-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/falcon_mamba-130m-hf") + + >>> inputs = tokenizer(text="My name is FalconMamba", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = FalconMambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values + FalconMambaCache() + ``` + """ + + is_compileable = True + + # TODO (joao): add layer_device_map arg and update code in `generate` accordingly + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + dtype: torch.dtype = torch.float16, + device: Union[torch.device, str, None] = None, + ): + self.max_batch_size = max_batch_size + self._dtype = dtype + self.intermediate_size = config.intermediate_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + + self.conv_states: list[torch.Tensor] = [] + self.ssm_states: list[torch.Tensor] = [] + device = torch.device(device) if device is not None else None + for _ in range(config.num_hidden_layers): + conv_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.conv_kernel_size, + device=device, + dtype=self._dtype, + ) + ssm_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.ssm_state_size, + device=device, + dtype=self._dtype, + ) + + torch._dynamo.mark_static_address(conv_state) + torch._dynamo.mark_static_address(ssm_state) + self.conv_states.append(conv_state) + self.ssm_states.append(ssm_state) + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + # 
This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used + # when the cache is initialized in the forward pass (e.g. FalconMamba) + if self.conv_states[layer_idx].device != new_conv_state.device: + self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) + + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx].zero_() + self.ssm_states[layer_idx] += new_ssm_state.to(self.ssm_states[layer_idx].device) + return self.ssm_states[layer_idx] + + def reset(self): + for layer_idx in range(len(self.conv_states)): + # In-place ops prevent breaking the static address + self.conv_states[layer_idx].zero_() + self.ssm_states[layer_idx].zero_() def rms_forward(hidden_states, variance_epsilon=1e-6): @@ -107,7 +213,7 @@ class FalconMambaMixer(nn.Module): self.activation = config.hidden_act self.act = ACT2FN[config.hidden_act] - self.use_mambapy = config.use_mambapy + self.use_falcon_mambapy = config.use_falcon_mambapy # projection of the input hidden states self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias) @@ -126,6 +232,7 @@ class FalconMambaMixer(nn.Module): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) self.use_bias = config.use_bias + self.warn_slow_implementation() # Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here self.register_buffer( "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), 
persistent=False @@ -135,8 +242,12 @@ class FalconMambaMixer(nn.Module): ) self.rms_eps = config.mixer_rms_eps + def warn_slow_implementation(self): + is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) + ) if not is_fast_path_available: - if self.use_mambapy: + if self.use_falcon_mambapy: if is_mambapy_available(): logger.warning_once( "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" @@ -157,7 +268,7 @@ class FalconMambaMixer(nn.Module): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): @@ -269,10 +380,10 @@ class FalconMambaMixer(nn.Module): contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) return contextualized_states - def slow_forward( - self, + # fmt: off + def slow_forward(self, input_states, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): @@ -344,7 +455,7 @@ class FalconMambaMixer(nn.Module): deltaB_u = discrete_B * hidden_states[:, :, :, None].float() # 3.c perform the recurrence y ← SSM(A, B, C)(x) - if self.use_mambapy and self.training and cache_params is None: + if self.use_falcon_mambapy and self.training and cache_params is None: hs = pscan( discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2) ) # [batch, seq_len, intermediate_size, ssm_state_size] @@ -371,21 +482,23 @@ class FalconMambaMixer(nn.Module): # 4. 
Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] return contextualized_states + # fmt: on - # Copied from transformers.models.mamba.modeling_mamba.MambaMixer.forward def forward( self, hidden_states, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): + is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) + ) if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling(): return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask) -# Copied from transformers.models.mamba.modeling_mamba.MambaRMSNorm with Mamba->FalconMamba class FalconMambaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -395,17 +508,15 @@ class FalconMambaRMSNorm(nn.Module): self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps - def extra_repr(self): - return f"{self.weight.shape[0]}, eps={self.variance_epsilon}" - - # Ignore copy def forward(self, hidden_states): return self.weight.to(hidden_states.device) * rms_forward( hidden_states, variance_epsilon=self.variance_epsilon ) + def extra_repr(self): + return f"{self.weight.shape[0]}, eps={self.variance_epsilon}" + -# Copied from transformers.models.mamba.modeling_mamba.MambaBlock with Mamba->FalconMamba,FalconMambaCache->MambaCache class FalconMambaBlock(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() @@ -418,7 +529,7 @@ class FalconMambaBlock(GradientCheckpointingLayer): def forward( self, hidden_states, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, 
cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): @@ -435,7 +546,6 @@ class FalconMambaBlock(GradientCheckpointingLayer): @auto_docstring -# Copied from transformers.models.mamba.modeling_mamba.MambaPreTrainedModel with Mamba->FalconMamba class FalconMambaPreTrainedModel(PreTrainedModel): config: FalconMambaConfig base_model_prefix = "backbone" @@ -507,13 +617,12 @@ class FalconMambaPreTrainedModel(PreTrainedModel): @dataclass @auto_docstring( custom_intro=""" - Class for the FALCONMAMBA model outputs. + Class for the FALCON_MAMBA model outputs. """ ) -# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->FALCONMAMBA,Mamba->FalconMamba,FalconMambaCache->MambaCache class FalconMambaOutput(ModelOutput): r""" - cache_params (`MambaCache`): + cache_params (`FalconMambaCache`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. @@ -521,7 +630,7 @@ class FalconMambaOutput(ModelOutput): """ last_hidden_state: Optional[torch.FloatTensor] = None - cache_params: Optional[MambaCache] = None + cache_params: Optional[FalconMambaCache] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None @@ -531,14 +640,13 @@ class FalconMambaOutput(ModelOutput): Base class for causal language model (or autoregressive) outputs. """ ) -# Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->FalconMamba,FalconMambaCache->MambaCache class FalconMambaCausalLMOutput(ModelOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
- cache_params (`MambaCache`): + cache_params (`FalconMambaCache`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. @@ -547,7 +655,7 @@ class FalconMambaCausalLMOutput(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None - cache_params: Optional[MambaCache] = None + cache_params: Optional[FalconMambaCache] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None @@ -577,7 +685,7 @@ class FalconMambaModel(FalconMambaPreTrainedModel): self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -585,7 +693,7 @@ class FalconMambaModel(FalconMambaPreTrainedModel): attention_mask: Optional[torch.LongTensor] = None, ) -> Union[tuple, FalconMambaOutput]: r""" - cache_params (`MambaCache`, *optional*): + cache_params (`FalconMambaCache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). 
use_cache (`bool`, *optional*): @@ -608,7 +716,7 @@ class FalconMambaModel(FalconMambaPreTrainedModel): if use_cache: if cache_params is None: - cache_params = MambaCache( + cache_params = FalconMambaCache( self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype ) cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device) @@ -623,6 +731,7 @@ class FalconMambaModel(FalconMambaPreTrainedModel): ) else: cache_params = None + hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None for mixer_block in self.layers: @@ -653,11 +762,10 @@ class FalconMambaModel(FalconMambaPreTrainedModel): @auto_docstring( custom_intro=""" - The FALCONMAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input + The FALCON_MAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings). """ ) -# Copied from transformers.models.mamba.modeling_mamba.MambaForCausalLM with MAMBA->FALCONMAMBA,Mamba->FalconMamba,mamba->falcon_mamba,FalconMambaCache->MambaCache class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] @@ -704,38 +812,32 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin): input_ids, inputs_embeds=None, use_cache=None, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, **kwargs, ): # Overwritten -- uses `cache_params` as opposed to `past_key_values` - - if use_cache: - # `cache_position` should have been initialized in `generate` - if cache_position is None: - raise ValueError( - "`cache_position` should not be None as it should have been initialized in " - "`model.generate`, you are responsible for passing in a valid `cache_position` if " - "you are calling 
`prepare_inputs_for_generation` directly with `use_cache=True`" - ) - if cache_position[0] > 0: - input_ids = input_ids[:, -1].unsqueeze(-1) - - if attention_mask is not None: - attention_mask = None - + model_inputs = {"input_ids": input_ids.contiguous()} + if use_cache and cache_params is None: + # we initialize the `cache_position` to full size of `conv_states` at prefill stage + # considering padding will be applied when input length is shorter, and truncation + # will be applied when it is longer, so it will be equivalent to always have it match + # the length of `cache_params.conv_states`, which is `config.conv_kernel` + cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device) + if inputs_embeds is not None: + model_inputs = {"inputs_embeds": inputs_embeds} + max_batch_size = inputs_embeds.size(0) else: - # we initialize the `cache_position` to full size of `conv_states` at prefill stage - # considering padding will be applied when input length is shorter, and truncation - # will be applied when it is longer, so it will be equivalent to always have it match - # the length of `cache_params.conv_states`, which is `config.conv_kernel` - cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device) + max_batch_size = input_ids.size(0) + cache_params = FalconMambaCache(self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype) - if inputs_embeds is not None and cache_params is None: + if use_cache and cache_position[0] > 0: + model_inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1).contiguous() + attention_mask = None + + if not use_cache and inputs_embeds is not None: model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids.contiguous()} model_inputs.update( { @@ -753,7 +855,7 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin): input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, 
inputs_embeds: Optional[torch.FloatTensor] = None, - cache_params: Optional[MambaCache] = None, + cache_params: Optional[FalconMambaCache] = None, labels: Optional[torch.LongTensor] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -762,7 +864,7 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin): **kwargs, # for now we need this for generation ) -> Union[tuple, FalconMambaCausalLMOutput]: r""" - cache_params (`MambaCache`, *optional*): + cache_params (`FalconMambaCache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -811,4 +913,4 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin): ) -__all__ = ["FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel"] +__all__ = ["FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel", "FalconMambaCache"] diff --git a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py new file mode 100644 index 0000000000..fdd1da3e2f --- /dev/null +++ b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py @@ -0,0 +1,540 @@ +# coding=utf-8 +# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch FALCONMAMBA model.""" + +from typing import Optional + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...utils import auto_docstring, logging +from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available +from ..mamba.configuration_mamba import MambaConfig +from ..mamba.modeling_mamba import ( + MambaBlock, + MambaCache, + MambaCausalLMOutput, + MambaForCausalLM, + MambaMixer, + MambaModel, + MambaOutput, + MambaPreTrainedModel, + MambaRMSNorm, +) + + +logger = logging.get_logger(__name__) + +if is_mambapy_available(): + from mambapy.pscan import pscan +else: + pscan = None + +if is_mamba_ssm_available(): + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + + from ...kernels.falcon_mamba import mamba_inner_fn +else: + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + + +class FalconMambaConfig(MambaConfig): + """ + This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the FALCON_MAMBA + [tiiuae/falcon-mamba-7b](https://huggingface.co/tiiuae/falcon-mamba-7b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + + Args: + vocab_size (`int`, *optional*, defaults to 50280): + Vocabulary size of the FALCON_MAMBA model. Defines the number of different tokens that can be represented by the + `input_ids` passed when calling [`FalconMambaModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + state_size (`int`, *optional*, defaults to 16): shape of the state space latents. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the model. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 0): + The id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 0): + The id of the end of sentence token in the vocabulary. + expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size. + conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel. + use_bias (`bool`, *optional*, defaults to `False`): + Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block + use_conv_bias (`bool`, *optional*, defaults to `True`): + Whether or not to use bias in the convolution layer of the mixer block. + hidden_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.1): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + residual_in_fp32 (`bool`, *optional*, defaults to `True`): + Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model + time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`): + Rank of the discretization projection matrix. 
`"auto"` means that it will default to `math.ceil(self.hidden_size / 16)` + time_step_scale (`float`, *optional*, defaults to 1.0): + Scale used to scale `dt_proj.bias`. + time_step_min (`float`, *optional*, defaults to 0.001): + Minimum `time_step` used to bound `dt_proj.bias`. + time_step_max (`float`, *optional*, defaults to 0.1): + Maximum `time_step` used to bound `dt_proj.bias`. + time_step_init_scheme (`str`, *optional*, defaults to `"random"`): + Init scheme used for `dt_proj.weight`. Should be one of `["random","uniform"]` + time_step_floor (`float`, *optional*, defaults to 0.0001): + Minimum clamping value of the `dt_proj.bias` layer initialization. + rescale_prenorm_residual (`bool`, *optional*, defaults to `False`): + Whether or not to rescale `out_proj` weights when initializing. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the cache should be used. + use_falcon_mambapy (`bool`, *optional*, defaults to `False`): + This argument corresponds to `use_mambapy` in MambaConfig. + Determines the fallback strategy during training if the CUDA-based official implementation of Mamba is not available. If `True`, the mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited. + mixer_rms_eps (`float`, *optional*, defaults to 1e-06): + The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states. 
+ + + Example: + + ```python + >>> from transformers import FalconMambaConfig, FalconMambaModel + + >>> # Initializing a FalconMamba configuration + >>> configuration = FalconMambaConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = FalconMambaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + def __init__( + self, + vocab_size=50280, + hidden_size=768, + state_size=16, + num_hidden_layers=32, + layer_norm_epsilon=1e-5, + pad_token_id=0, + bos_token_id=0, + eos_token_id=0, + expand=2, + conv_kernel=4, + use_bias=False, + use_conv_bias=True, + hidden_act="silu", + initializer_range=0.1, + residual_in_fp32=True, + time_step_rank="auto", + time_step_scale=1.0, + time_step_min=0.001, + time_step_max=0.1, + time_step_init_scheme="random", + time_step_floor=1e-4, + rescale_prenorm_residual=False, + use_cache=True, + use_falcon_mambapy=False, + mixer_rms_eps=1e-6, + **kwargs, + ): + super().__init__( + vocab_size=vocab_size, + hidden_size=hidden_size, + state_size=state_size, + num_hidden_layers=num_hidden_layers, + layer_norm_epsilon=layer_norm_epsilon, + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + expand=expand, + conv_kernel=conv_kernel, + use_bias=use_bias, + use_conv_bias=use_conv_bias, + hidden_act=hidden_act, + initializer_range=initializer_range, + residual_in_fp32=residual_in_fp32, + time_step_rank=time_step_rank, + time_step_scale=time_step_scale, + time_step_min=time_step_min, + time_step_max=time_step_max, + time_step_init_scheme=time_step_init_scheme, + time_step_floor=time_step_floor, + rescale_prenorm_residual=rescale_prenorm_residual, + use_cache=use_cache, + use_falcon_mambapy=use_falcon_mambapy, + **kwargs, + ) + self.mixer_rms_eps = mixer_rms_eps + + +class FalconMambaCache(MambaCache): + pass + + +def rms_forward(hidden_states, variance_epsilon=1e-6): + """ + Calculates simple RMSNorm with no learnable 
weights. `MambaRMSNorm` will + leverage this in order to multiply the final result with the RMSNorm weight + + Args: + hidden_states (`torch.Tensor`): + Hidden states to normalize + variance_epsilon (`float`): + The eps value to add in the square root scaling factor + """ + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon) + return hidden_states.to(input_dtype) + + +class FalconMambaMixer(MambaMixer): + def warn_slow_implementation(self): + is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) + ) + if not is_fast_path_available: + if self.use_falcon_mambapy: + if is_mambapy_available(): + logger.warning_once( + "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + else: + raise ImportError( + "use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py." + ) + else: + logger.warning_once( + "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py." 
+ ) + + def __init__(self, config: FalconMambaConfig, layer_idx: int): + super().__init__(config, layer_idx) + # Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here + self.register_buffer( + "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False + ) + self.register_buffer( + "dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False + ) + self.rms_eps = config.mixer_rms_eps + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[FalconMambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states).transpose(1, 2) + + if self.training and cache_params is None: # Doesn't support outputting the states -> used for training + contextualized_states = mamba_inner_fn( + projected_states, + self.conv1d.weight, + self.conv1d.bias if self.use_conv_bias else None, + self.x_proj.weight, + self.dt_proj.weight, + self.out_proj.weight, + self.out_proj.bias.float() if self.use_bias else None, + -torch.exp(self.A_log.float()), + None, # input-dependent B + None, # input-dependent C + self.D.float(), + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + b_rms_weight=self.b_c_rms, + c_rms_weight=self.b_c_rms, + dt_rms_weight=self.dt_rms, + b_c_dt_rms_eps=self.rms_eps, + ) + + else: + hidden_states, gate = projected_states.chunk(2, dim=1) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 2. 
Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if cache_params is not None and cache_position[0] > 0: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), + cache_params.conv_states[self.layer_idx], + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.unsqueeze(-1) + else: + if cache_params is not None: + conv_states = nn.functional.pad( + hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.update_conv_state(self.layer_idx, conv_states, cache_position) + hidden_states = causal_conv1d_fn( + hidden_states, conv_weights, self.conv1d.bias, activation=self.activation + ) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + + B = rms_forward(B, variance_epsilon=self.rms_eps) + C = rms_forward(C, variance_epsilon=self.rms_eps) + time_step = rms_forward(time_step, variance_epsilon=self.rms_eps) + + # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module + # at the price of a small overhead. 
+ if hasattr(self.config, "_pre_quantization_dtype"): + discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2) + else: + discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2) + + A = -torch.exp(self.A_log.float()) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None + if cache_params is not None and cache_position[0] > 0: + scan_outputs = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states[..., 0], + discrete_time_step[..., 0], + A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + time_proj_bias, + dt_softplus=True, + ).unsqueeze(-1) + else: + scan_outputs, ssm_state = selective_scan_fn( + hidden_states, + discrete_time_step, + A, + B.transpose(1, 2), + C.transpose(1, 2), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + return_last_state=True, + ) + if ssm_state is not None and cache_params is not None: + cache_params.update_ssm_state(self.layer_idx, ssm_state) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) + return contextualized_states + + def slow_forward( + self, + input_states, + cache_params: Optional[FalconMambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + # 1. Gated MLP's linear projection + projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len] + hidden_states, gate = projected_states.chunk(2, dim=1) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 2. 
Convolution sequence transformation + if cache_params is not None: + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = ssm_state.to(hidden_states.device) + # use `cache_position.shape[0]` to check whether we are in prefill + # stage, it's equivalent to check `cache_position[0] == 0`, which + # breaks dynamo fullgraph constraints + if cache_position is not None and cache_position.shape[0] == self.conv_kernel_size: + conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) + + cache_params.update_conv_state(self.layer_idx, conv_state, cache_position) + hidden_states = self.act( + self.conv1d(hidden_states)[..., :seq_len] + ) # [batch, intermediate_size, seq_len] + else: + conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position) + conv_state = conv_state.to(self.conv1d.weight.device) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = ( + self.act(hidden_states).to(dtype).unsqueeze(-1) + ) # [batch, intermediate_size, 1] : decoding + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype + ) + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 3. State Space Model sequence transformation + # 3.a. 
Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2] + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + + B = rms_forward(B, variance_epsilon=self.rms_eps) + C = rms_forward(C, variance_epsilon=self.rms_eps) + time_step = rms_forward(time_step, variance_epsilon=self.rms_eps) + + discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size] + discrete_time_step = nn.functional.softplus(discrete_time_step).transpose( + 1, 2 + ) # [batch, intermediate_size, seq_len] + + # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM) + A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size] + discrete_A = torch.exp( + A[None, :, None, :] * discrete_time_step[:, :, :, None] + ) # [batch, intermediate_size, seq_len, ssm_state_size] + discrete_B = ( + discrete_time_step[:, :, :, None] * B[:, None, :, :].float() + ) # [batch, intermediate_size, seq_len, ssm_state_size] + deltaB_u = discrete_B * hidden_states[:, :, :, None].float() + + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + if self.use_falcon_mambapy and self.training and cache_params is None: + hs = pscan( + discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2) + ) # [batch, seq_len, intermediate_size, ssm_state_size] + scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len] + scan_output = scan_output + hidden_states * self.D[None, :, None] + scan_output = scan_output * self.act(gate) + else: + scan_outputs = [] + for i in range(seq_len): + ssm_state = ( + discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] + ) # [batch, intermediate_size, ssm_state] + scan_output = torch.matmul( + ssm_state.to(dtype), C[:, i, :].unsqueeze(-1) + ) # [batch, intermediate_size, 1] + scan_outputs.append(scan_output[:, :, 0]) + scan_output = 
torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len] + scan_output = scan_output + (hidden_states * self.D[None, :, None]) + scan_output = scan_output * self.act(gate) + + if cache_params is not None: + cache_params.update_ssm_state(self.layer_idx, ssm_state) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] + return contextualized_states + + def forward( + self, + hidden_states, + cache_params: Optional[FalconMambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) + ) + if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling(): + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class FalconMambaRMSNorm(MambaRMSNorm): + def forward(self, hidden_states): + return self.weight.to(hidden_states.device) * rms_forward( + hidden_states, variance_epsilon=self.variance_epsilon + ) + + +class FalconMambaBlock(MambaBlock): + pass + + +@auto_docstring +class FalconMambaPreTrainedModel(MambaPreTrainedModel): + pass + + +class FalconMambaOutput(MambaOutput): + pass + + +class FalconMambaCausalLMOutput(MambaCausalLMOutput): + pass + + +class FalconMambaModel(MambaModel, FalconMambaPreTrainedModel): + def __init__(self, config): + FalconMambaPreTrainedModel.__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [FalconMambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)] + ) + + self.gradient_checkpointing = False + self.norm_f = FalconMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) 
+ # Initialize weights and apply final processing + self.post_init() + + def load_hook(self, state_dict, prefix, *args): + raise AttributeError("Not needed for FalconMamba") + + +class FalconMambaForCausalLM(MambaForCausalLM): + pass + + +__all__ = [ + "FalconMambaForCausalLM", + "FalconMambaModel", + "FalconMambaPreTrainedModel", + "FalconMambaCache", + "FalconMambaConfig", +] diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index fd0e362e9f..06d87b4d5c 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -24,7 +24,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import MambaCache +from ...configuration_utils import PretrainedConfig from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel @@ -55,9 +55,106 @@ if is_causal_conv1d_available(): else: causal_conv1d_update, causal_conv1d_fn = None, None -is_fast_path_available = all( - (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) -) + +class MambaCache: + """ + Cache for mamba model which does not have attention mechanism and key value states. + + Arguments: + config (`PretrainedConfig`): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. + dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default `dtype` to use when initializing the layer. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. Should be the same as the layer. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache + + >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") + + >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values + MambaCache() + ``` + """ + + is_compileable = True + + # TODO (joao): add layer_device_map arg and update code in `generate` accordingly + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + dtype: torch.dtype = torch.float16, + device: Union[torch.device, str, None] = None, + ): + self.max_batch_size = max_batch_size + self._dtype = dtype + self.intermediate_size = config.intermediate_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + + self.conv_states: list[torch.Tensor] = [] + self.ssm_states: list[torch.Tensor] = [] + device = torch.device(device) if device is not None else None + for _ in range(config.num_hidden_layers): + conv_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.conv_kernel_size, + device=device, + dtype=self._dtype, + ) + ssm_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.ssm_state_size, + device=device, + dtype=self._dtype, + ) + + torch._dynamo.mark_static_address(conv_state) + torch._dynamo.mark_static_address(ssm_state) + self.conv_states.append(conv_state) + self.ssm_states.append(ssm_state) + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + # This `if` blocks is only reached in multigpu and if 
`layer_device_map` is not passed. It is used + # when the cache is initialized in the forward pass (e.g. Mamba) + if self.conv_states[layer_idx].device != new_conv_state.device: + self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) + + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx].zero_() + self.ssm_states[layer_idx] += new_ssm_state.to(self.ssm_states[layer_idx].device) + return self.ssm_states[layer_idx] + + def reset(self): + for layer_idx in range(len(self.conv_states)): + # In-place ops prevent breaking the static address + self.conv_states[layer_idx].zero_() + self.ssm_states[layer_idx].zero_() class MambaMixer(nn.Module): @@ -109,6 +206,12 @@ class MambaMixer(nn.Module): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) self.use_bias = config.use_bias + self.warn_slow_implementation() + + def warn_slow_implementation(self): + is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) + ) if not is_fast_path_available: if self.use_mambapy: if is_mambapy_available(): @@ -319,6 +422,9 @@ class MambaMixer(nn.Module): cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, ): + is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) + ) if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling(): return 
self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask) @@ -650,32 +756,26 @@ class MambaForCausalLM(MambaPreTrainedModel, GenerationMixin): **kwargs, ): # Overwritten -- uses `cache_params` as opposed to `past_key_values` - - if use_cache: - # `cache_position` should have been initialized in `generate` - if cache_position is None: - raise ValueError( - "`cache_position` should not be None as it should have been initialized in " - "`model.generate`, you are responsible for passing in a valid `cache_position` if " - "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`" - ) - if cache_position[0] > 0: - input_ids = input_ids[:, -1].unsqueeze(-1) - - if attention_mask is not None: - attention_mask = None - + model_inputs = {"input_ids": input_ids.contiguous()} + if use_cache and cache_params is None: + # we initialize the `cache_position` to full size of `conv_states` at prefill stage + # considering padding will be applied when input length is shorter, and truncation + # will be applied when it is longer, so it will be equivalent to always have it match + # the length of `cache_params.conv_states`, which is `config.conv_kernel` + cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device) + if inputs_embeds is not None: + model_inputs = {"inputs_embeds": inputs_embeds} + max_batch_size = inputs_embeds.size(0) else: - # we initialize the `cache_position` to full size of `conv_states` at prefill stage - # considering padding will be applied when input length is shorter, and truncation - # will be applied when it is longer, so it will be equivalent to always have it match - # the length of `cache_params.conv_states`, which is `config.conv_kernel` - cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device) + max_batch_size = input_ids.size(0) + cache_params = 
MambaCache(self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype) - if inputs_embeds is not None and cache_params is None: + if use_cache and cache_position[0] > 0: + model_inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1).contiguous() + attention_mask = None + + if not use_cache and inputs_embeds is not None: model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids.contiguous()} model_inputs.update( { @@ -751,4 +851,4 @@ class MambaForCausalLM(MambaPreTrainedModel, GenerationMixin): ) -__all__ = ["MambaForCausalLM", "MambaModel", "MambaPreTrainedModel"] +__all__ = ["MambaForCausalLM", "MambaModel", "MambaPreTrainedModel", "MambaCache"] diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 2511c1809d..5a83186fb0 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -970,38 +970,33 @@ class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin): **kwargs, ): # Overwritten -- uses `cache_params` as opposed to `past_key_values` - - if use_cache: - # `cache_position` should have been initialized in `generate` - if cache_position is None: - raise ValueError( - "`cache_position` should not be None as it should have been initialized in " - "`model.generate`, you are responsible for passing in a valid `cache_position` if " - "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`" - ) - if cache_position[0] > 0: - input_ids = input_ids[:, -1][..., None] - - if attention_mask is not None: - attention_mask = None + model_inputs = {"input_ids": input_ids.contiguous()} + if use_cache and cache_params is None: + # we initialize the `cache_position` to full size of `conv_states` at prefill stage + # considering padding will be applied when input length is shorter, and truncation + # will be applied when it is longer, so it will be equivalent to always have 
it match + # the length of `cache_params.conv_states`, which is `config.conv_kernel` + cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device) + if inputs_embeds is not None: + model_inputs = {"inputs_embeds": inputs_embeds} + max_batch_size = inputs_embeds.size(0) else: - # we initialize the `cache_position` to full size of `conv_states` at prefill stage - # considering padding will be applied when input length is shorter, and truncation - # will be applied when it is longer, so it will be equivalent to always have it match - # the length of `cache_params.conv_states`, which is `config.conv_kernel` - cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device) + max_batch_size = input_ids.size(0) + cache_params = Mamba2Cache(self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype) - if inputs_embeds is not None and cache_params is None: + if use_cache and cache_position[0] > 0: + model_inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1).contiguous() + attention_mask = None + + if not use_cache and inputs_embeds is not None: model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} model_inputs.update( { - "attention_mask": attention_mask, "cache_params": cache_params, "use_cache": use_cache, "cache_position": cache_position, + "attention_mask": attention_mask, } ) return model_inputs diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index e865237485..7026bf1697 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -44,13 +44,6 @@ class HybridCache(metaclass=DummyObject): requires_backends(self, ["torch"]) -class MambaCache(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - class OffloadedCache(metaclass=DummyObject): _backends = ["torch"] diff --git 
a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py index cada419ea0..855d0f2103 100644 --- a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py +++ b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py @@ -26,7 +26,6 @@ from transformers.testing_utils import ( require_torch_accelerator, require_torch_large_accelerator, require_torch_multi_accelerator, - require_torch_multi_gpu, slow, torch_device, ) @@ -41,10 +40,10 @@ if is_torch_available(): import torch from transformers import ( + FalconMambaCache, FalconMambaForCausalLM, FalconMambaModel, ) - from transformers.cache_utils import MambaCache # Copied from transformers.tests.models.mamba.MambaModelTester with Mamba->FalconMamba,mamba->falcon_mamba @@ -312,31 +311,6 @@ class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest def test_config(self): self.config_tester.run_common_tests() - @require_torch_multi_gpu - def test_multi_gpu_data_parallel_forward(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - # some params shouldn't be scattered by nn.DataParallel - # so just remove them if they are present. 
- blacklist_non_batched_params = ["cache_params"] - for k in blacklist_non_batched_params: - inputs_dict.pop(k, None) - - # move input tensors to cuda:O - for k, v in inputs_dict.items(): - if torch.is_tensor(v): - inputs_dict[k] = v.to(0) - - for model_class in self.all_model_classes: - model = model_class(config=config) - model.to(0) - model.eval() - - # Wrap model in nn.DataParallel - model = torch.nn.DataParallel(model) - with torch.no_grad(): - _ = model(**self._prepare_for_class(inputs_dict, model_class)) - def test_falcon_mamba_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_falcon_mamba_model(*config_and_inputs) @@ -411,7 +385,7 @@ class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, MambaCache): # MODIFIED PART START + if isinstance(tuple_object, FalconMambaCache): # MODIFIED PART START recursive_check(tuple_object.conv_states, dict_object.conv_states) recursive_check(tuple_object.ssm_states, dict_object.ssm_states) elif isinstance(tuple_object, (list, tuple)): # MODIFIED PART END @@ -458,6 +432,10 @@ class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + @unittest.skip("Mamba models do not support DDP.") + def test_multi_gpu_data_parallel_forward(self): + pass + @require_torch @require_torch_accelerator @@ -497,7 +475,9 @@ class FalconMambaIntegrationTests(unittest.TestCase): @require_bitsandbytes def test_generation_4bit(self): quantization_config = BitsAndBytesConfig(load_in_4bit=True) - model = AutoModelForCausalLM.from_pretrained(self.model_id, quantization_config=quantization_config) + model = 
AutoModelForCausalLM.from_pretrained(self.model_id, quantization_config=quantization_config).to( + torch_device + ) inputs = self.tokenizer(self.text, return_tensors="pt").to(torch_device) out = model.generate(**inputs, max_new_tokens=20, do_sample=False) @@ -513,6 +493,7 @@ class FalconMambaIntegrationTests(unittest.TestCase): inputs = self.tokenizer(self.text, return_tensors="pt").to(torch_device) out = model.generate(**inputs, max_new_tokens=20, do_sample=False) + print(self.tokenizer.batch_decode(out, skip_special_tokens=False)[0]) self.assertEqual( self.tokenizer.batch_decode(out, skip_special_tokens=False)[0], @@ -543,7 +524,7 @@ class FalconMambaIntegrationTests(unittest.TestCase): inputs = tok(texts, return_tensors="pt", padding=True, return_token_type_ids=False).to(torch_device) model = AutoModelForCausalLM.from_pretrained(model_id, device_map=0, torch_dtype=torch.float16) - out = model.generate(**inputs, max_new_tokens=20) + out = model.generate(**inputs, max_new_tokens=20, do_sample=False) out = tok.batch_decode(out, skip_special_tokens=True) self.assertListEqual(out, EXPECTED_OUTPUT) @@ -553,7 +534,7 @@ class FalconMambaIntegrationTests(unittest.TestCase): inputs_embeds = model.get_input_embeddings()(inputs.pop("input_ids")) inputs["inputs_embeds"] = inputs_embeds - out = model.generate(**inputs, max_new_tokens=20) + out = model.generate(**inputs, max_new_tokens=20, do_sample=False) out = tok.batch_decode(out, skip_special_tokens=True) EXPECTED_OUTPUTS = Expectations( diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index b570d1a130..e99c8b1e57 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -20,7 +20,7 @@ from unittest.util import safe_repr from parameterized import parameterized from transformers import AutoTokenizer, MambaConfig, is_torch_available -from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, 
torch_device +from transformers.testing_utils import require_torch, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -32,10 +32,10 @@ if is_torch_available(): import torch from transformers import ( + MambaCache, MambaForCausalLM, MambaModel, ) - from transformers.models.mamba.modeling_mamba import MambaCache class MambaModelTester: @@ -279,31 +279,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi def test_config(self): self.config_tester.run_common_tests() - @require_torch_multi_gpu - def test_multi_gpu_data_parallel_forward(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - # some params shouldn't be scattered by nn.DataParallel - # so just remove them if they are present. - blacklist_non_batched_params = ["cache_params"] - for k in blacklist_non_batched_params: - inputs_dict.pop(k, None) - - # move input tensors to cuda:O - for k, v in inputs_dict.items(): - if torch.is_tensor(v): - inputs_dict[k] = v.to(0) - - for model_class in self.all_model_classes: - model = model_class(config=config) - model.to(0) - model.eval() - - # Wrap model in nn.DataParallel - model = torch.nn.DataParallel(model) - with torch.no_grad(): - _ = model(**self._prepare_for_class(inputs_dict, model_class)) - def test_mamba_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_mamba_model(*config_and_inputs) @@ -452,6 +427,10 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.hidden_size), ) + @unittest.skip("Mamba models do not support DDP.") + def test_multi_gpu_data_parallel_forward(self): + pass + @require_torch class MambaIntegrationTests(unittest.TestCase): @@ -547,11 +526,11 @@ class MambaIntegrationTests(unittest.TestCase): torch_device ) - output = 
model.generate(input_ids, max_new_tokens=20, cache_implementation="mamba") + output = model.generate(input_ids, max_new_tokens=20) output_sentence = self.tokenizer.decode(output[0].tolist()) self.assertEqual(output_sentence, expected_output) model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead") - output = model.generate(input_ids, max_new_tokens=20, cache_implementation="mamba") + output = model.generate(input_ids, max_new_tokens=20) output_sentence = self.tokenizer.decode(output[0].tolist()) self.assertEqual(output_sentence, expected_output) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index bb73052b25..b1998b7cfe 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -62,8 +62,6 @@ if is_torch_available(): TEST_CACHE_IMPLEMENTATIONS = [ cache_name for cache_name in ALL_CACHE_IMPLEMENTATIONS - # TODO (joao): Mamba is not compatible with most models, remove from `ALL_CACHE_IMPLEMENTATIONS`? - if cache_name != "mamba" # TODO (joao): offloaded_hybrid == offloaded_hybrid_chunked, deprecate one of them if cache_name != "offloaded_hybrid" ] diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 1f32716c8f..3b447dc7bf 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -141,7 +141,12 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer): return updated_node def leave_ImportFrom(self, original_node, updated_node): - """The imports from other file types (configuration, processing etc) should use original model name.""" + """ + The imports from other file types (configuration, processing etc) should use original model name. + Also, no replaces on absolute imports (e.g. 
`from mamba_ssm import ...`) + """ + if len(original_node.relative) == 0: # no replaces on absolute imports + return original_node if self.original_new_model_name != self.new_name and m.matches(updated_node.module, m.Name()): patterns = "|".join(ALL_FILE_TYPES) regex = rf"({patterns})_{self.new_name}"