Refactor MambaCache to modeling_mamba.py (#38086)

* Refactor MambaCache to modeling_mamba.py (parity with Zamba)

* ruff

* fix dummies

* update

* update

* remove mamba ref in cache tests

* remove cache_implementation from tests

* update

* ruff

* ruff

* sneaky regression

* model consistency

* fix test_multi_gpu_data_parallel_forward

* fix falcon slow tests

* ruff

* ruff

* add sample false

* try to fix slow tests

* Revert "fix test_multi_gpu_data_parallel_forward"

This reverts commit 66b7162c7c5c5ce8a73ccf48cffc8a96343ebb33.

* fix tests on nvidia t4, remove dataparallel tests from mamba

* ruff

* remove DDP tests from mamba and falcon_mamba

* add explicit error for MambaCache

* mamba2 also needs to init cache in prepare_inputs_for_generation

* ruff

* ruff

* move MambaCache to its own file

* ruff

* unprotected import fix

* another attempt to fix unprotected imports

* Revert "another attempt to fix unprotected imports"

This reverts commit 2338354fcab630de5899321f5daced5fb312c2a2.

* fixing unprotected import, attempt 3

* Update src/transformers/cache_utils.py

* ruff's fault

* fix arthur review

* modular falcon mamba

* found a hack

* fix config docs

* fix docs

* add export info

* merge modular falcon branch

* oopsie

* fix fast path failing

* new approach

* oopsie

* fix types

* Revert new pragma in modular

This reverts commit 80b1cf160ee251536f07c40b8a0857d499e70db6.

* trying another modular workaround

* review & fix ci

* oopsie

* clear prepare_inputs on mamba/mamba2/falcon_mamba
This commit is contained in:
Manuel de Prada Corral
2025-07-21 14:59:36 +02:00
committed by GitHub
parent a419a40234
commit 1aa7256f01
16 changed files with 1033 additions and 307 deletions

View File

@@ -110,6 +110,13 @@ outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
## FalconMambaCache
[[autodoc]] FalconMambaCache
- update_conv_state
- update_ssm_state
- reset
## FalconMambaConfig
[[autodoc]] FalconMambaConfig

View File

@@ -116,6 +116,13 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
trainer.train()
```
## MambaCache
[[autodoc]] MambaCache
- update_conv_state
- update_ssm_state
- reset
## MambaConfig
[[autodoc]] MambaConfig

View File

@@ -371,7 +371,6 @@ else:
"EncoderDecoderCache",
"HQQQuantizedCache",
"HybridCache",
"MambaCache",
"OffloadedCache",
"OffloadedStaticCache",
"QuantizedCache",

View File

@@ -2,6 +2,7 @@ import copy
import importlib.metadata
import json
import os
import warnings
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Optional, Union
@@ -18,6 +19,7 @@ from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_great
if is_hqq_available():
from hqq.core.quantize import Quantizer as HQQQuantizer
logger = logging.get_logger(__name__)
@@ -2091,106 +2093,6 @@ class OffloadedHybridCache(HybridChunkedCache):
self.device_value_cache[self.active_device_layer].fill_(0.0)
class MambaCache:
"""
Cache for mamba model which does not have attention mechanism and key value states.
Arguments:
config (`PretrainedConfig):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used.
dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
The default `dtype` to use when initializing the layer.
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. Should be the same as the layer.
Example:
```python
>>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache
>>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
>>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values
MambaCache()
```
"""
is_compileable = True
# TODO (joao): add layer_device_map arg and update code in `generate` accordingly
def __init__(
self,
config: PretrainedConfig,
max_batch_size: int,
dtype: torch.dtype = torch.float16,
device: Union[torch.device, str, None] = None,
):
self.max_batch_size = max_batch_size
self._dtype = dtype
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.conv_states: list[torch.Tensor] = []
self.ssm_states: list[torch.Tensor] = []
device = torch.device(device) if device is not None else None
for _ in range(config.num_hidden_layers):
conv_state: torch.Tensor = torch.zeros(
self.max_batch_size,
self.intermediate_size,
self.conv_kernel_size,
device=device,
dtype=self._dtype,
)
ssm_state: torch.Tensor = torch.zeros(
self.max_batch_size,
self.intermediate_size,
self.ssm_state_size,
device=device,
dtype=self._dtype,
)
torch._dynamo.mark_static_address(conv_state)
torch._dynamo.mark_static_address(ssm_state)
self.conv_states.append(conv_state)
self.ssm_states.append(ssm_state)
def update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
# This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used
# when the cache is initialized in the forward pass (e.g. Mamba)
if self.conv_states[layer_idx].device != new_conv_state.device:
self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device)
conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
conv_state = conv_state.roll(shifts=-1, dims=-1)
conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
self.conv_states[layer_idx].zero_()
self.conv_states[layer_idx] += conv_state
return self.conv_states[layer_idx]
def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
return self.ssm_states[layer_idx]
def reset(self):
for layer_idx in range(len(self.conv_states)):
# In-place ops prevent breaking the static address
self.conv_states[layer_idx].zero_()
self.ssm_states[layer_idx].zero_()
class OffloadedStaticCache(StaticCache):
"""
Static cache class to be used with `torch.compile(model)` that offloads to the CPU or
@@ -2461,3 +2363,122 @@ class OffloadedStaticCache(StaticCache):
self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True)
self._device_value_cache[layer_idx & 1].copy_(self.value_cache[layer_idx], non_blocking=True)
# TODO (manuel, joao): remove this class, it is here only for backwards compatibility
# PEP 562: Lazy loading for deprecated location of MambaCache
def __getattr__(name: str) -> Any:
if name == "MambaCache":
warnings.warn(
(
"Importing `MambaCache` from `transformers.cache_utils` is deprecated and will be removed "
"in a future version. Please import it from `transformers` or `transformers.models.mamba.cache_mamba` instead."
),
FutureWarning,
stacklevel=2,
)
class MambaCache:
"""
Importing `MambaCache` from `transformers.cache_utils` is deprecated and will be removed
in a future version. Please import it from `transformers` or `transformers.models.mamba.cache_mamba` instead.
Cache for mamba model which does not have attention mechanism and key value states.
Arguments:
config (`PretrainedConfig):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used.
dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
The default `dtype` to use when initializing the layer.
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. Should be the same as the layer.
Example:
```python
>>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache
>>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
>>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values
MambaCache()
```
"""
is_compileable = True
# TODO (joao): add layer_device_map arg and update code in `generate` accordingly
def __init__(
self,
config,
max_batch_size: int,
dtype: torch.dtype = torch.float16,
device: Union[torch.device, str, None] = None,
):
self.max_batch_size = max_batch_size
self._dtype = dtype
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.conv_states: list[torch.Tensor] = []
self.ssm_states: list[torch.Tensor] = []
device = torch.device(device) if device is not None else None
for _ in range(config.num_hidden_layers):
conv_state: torch.Tensor = torch.zeros(
self.max_batch_size,
self.intermediate_size,
self.conv_kernel_size,
device=device,
dtype=self._dtype,
)
ssm_state: torch.Tensor = torch.zeros(
self.max_batch_size,
self.intermediate_size,
self.ssm_state_size,
device=device,
dtype=self._dtype,
)
torch._dynamo.mark_static_address(conv_state)
torch._dynamo.mark_static_address(ssm_state)
self.conv_states.append(conv_state)
self.ssm_states.append(ssm_state)
def update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
# This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used
# when the cache is initialized in the forward pass (e.g. Mamba)
if self.conv_states[layer_idx].device != new_conv_state.device:
self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device)
conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
conv_state = conv_state.roll(shifts=-1, dims=-1)
conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
self.conv_states[layer_idx].zero_()
self.conv_states[layer_idx] += conv_state
return self.conv_states[layer_idx]
def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
return self.ssm_states[layer_idx]
def reset(self):
for layer_idx in range(len(self.conv_states)):
# In-place ops prevent breaking the static address
self.conv_states[layer_idx].zero_()
self.ssm_states[layer_idx].zero_()
return MambaCache
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -54,7 +54,6 @@ if is_torch_available():
HQQQuantizedCache,
HybridCache,
HybridChunkedCache,
MambaCache,
OffloadedHybridCache,
OffloadedStaticCache,
QuantizedCacheConfig,
@@ -75,7 +74,6 @@ if is_torch_available():
"hybrid_chunked": HybridChunkedCache,
"offloaded_hybrid": OffloadedHybridCache,
"offloaded_hybrid_chunked": OffloadedHybridCache,
"mamba": MambaCache,
}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
ALL_CACHE_IMPLEMENTATIONS = (
@@ -186,7 +184,6 @@ class GenerationConfig(PushToHubMixin):
- `"offloaded_static"`: [`OffloadedStaticCache`]
- `"sliding_window"`: [`SlidingWindowCache`]
- `"hybrid"`: [`HybridCache`]
- `"mamba"`: [`MambaCache`]
- `"quantized"`: [`QuantizedCache`]
If none is specified, we will use the default cache for the model (which is often [`DynamicCache`]). See

View File

@@ -1916,9 +1916,8 @@ class GenerationMixin(ContinuousMixin):
or isinstance(
cache_to_check, (HybridChunkedCache, OffloadedHybridCache)
) # due to internal slicing, we always re-init
or cache_to_check.max_cache_len < max_cache_len
)
if cache_implementation != "mamba":
need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
if requires_cross_attention_cache and hasattr(self, "_cache"):
need_new_cache = (
@@ -1957,9 +1956,9 @@ class GenerationMixin(ContinuousMixin):
def _supports_default_dynamic_cache(cls) -> bool:
"""
Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`.
This adds exception for some models like `Jamba` model which uses its own `HybridMambaAttentionDynamicCache`
This adds exception for some models like `Mamba` models which use their own caches
and do not need to initialize the Cache in advance in order to save memory (because no back and forth
`to_legacy_cache` and `from_legacy_cache` will be performed for `HybridMambaAttentionDynamicCache`).
`to_legacy_cache` and `from_legacy_cache` will be performed for mamba-based models).
"""
# NOTE: remove xlnet/reformer when the models are deprecated, non-standard model architecture/cache name
return not cls._is_stateful and all(
@@ -2016,7 +2015,7 @@ class GenerationMixin(ContinuousMixin):
if generation_config.use_cache is False:
return
# Quick escape route 3: model that only supports legacy caches = nothing to prepare
# Quick escape route 3: model that only supports legacy caches or models that supply it in `prepare_inputs_for_generation` (mamba, zamba, ...)
if not self._supports_default_dynamic_cache():
if generation_config.cache_implementation is not None:
warnings.warn(

View File

@@ -1,5 +1,11 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/falcon_mamba/modular_falcon_mamba.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_falcon_mamba.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,15 +18,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FALCONMAMBA configuration"""
import math
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class FalconMambaConfig(PretrainedConfig):
@@ -79,10 +80,13 @@ class FalconMambaConfig(PretrainedConfig):
Whether or not to rescale `out_proj` weights when initializing.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the cache should be used.
use_mambapy (`bool`, *optional*, defaults to `False`):
Determines the fallback strategy during training if the CUDA-based official implementation of FalconMamba is not available. If `True`, the falcon_mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
use_falcon_mambapy (`bool`, *optional*, defaults to `False`):
This argument corresponds to `use_mambapy` in MambaConfig.
Determines the fallback strategy during training if the CUDA-based official implementation of Mamba is not available. If `True`, the mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
mixer_rms_eps (`float`, *optional*, defaults to 1e-06):
The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states.
Example:
```python
@@ -125,10 +129,11 @@ class FalconMambaConfig(PretrainedConfig):
time_step_floor=1e-4,
rescale_prenorm_residual=False,
use_cache=True,
use_mambapy=False,
use_falcon_mambapy=False,
mixer_rms_eps=1e-6,
**kwargs,
):
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.state_size = state_size
@@ -153,10 +158,8 @@ class FalconMambaConfig(PretrainedConfig):
self.rescale_prenorm_residual = rescale_prenorm_residual
self.residual_in_fp32 = residual_in_fp32
self.use_cache = use_cache
self.use_mambapy = use_mambapy
self.use_falcon_mambapy = use_falcon_mambapy
self.mixer_rms_eps = mixer_rms_eps
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
__all__ = ["FalconMambaConfig"]

View File

@@ -1,3 +1,9 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/falcon_mamba/modular_falcon_mamba.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_falcon_mamba.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
#
@@ -12,29 +18,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch FALCONMAMBA model."""
import math
from dataclasses import dataclass
from typing import Any, Optional, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import MambaCache
from ...configuration_utils import PretrainedConfig
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available
from ...utils.import_utils import (
is_causal_conv1d_available,
is_mamba_ssm_available,
is_mambapy_available,
)
from .configuration_falcon_mamba import FalconMambaConfig
logger = logging.get_logger(__name__)
if is_mambapy_available():
from mambapy.pscan import pscan
else:
@@ -53,9 +59,109 @@ if is_causal_conv1d_available():
else:
causal_conv1d_update, causal_conv1d_fn = None, None
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
logger = logging.get_logger(__name__)
class FalconMambaCache:
"""
Cache for falcon_mamba model which does not have attention mechanism and key value states.
Arguments:
config (`PretrainedConfig):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used.
dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
The default `dtype` to use when initializing the layer.
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. Should be the same as the layer.
Example:
```python
>>> from transformers import AutoTokenizer, FalconMambaForCausalLM, FalconMambaCache
>>> model = FalconMambaForCausalLM.from_pretrained("state-spaces/falcon_mamba-130m-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/falcon_mamba-130m-hf")
>>> inputs = tokenizer(text="My name is FalconMamba", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> past_key_values = FalconMambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values
FalconMambaCache()
```
"""
is_compileable = True
# TODO (joao): add layer_device_map arg and update code in `generate` accordingly
def __init__(
self,
config: PretrainedConfig,
max_batch_size: int,
dtype: torch.dtype = torch.float16,
device: Union[torch.device, str, None] = None,
):
self.max_batch_size = max_batch_size
self._dtype = dtype
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.conv_states: list[torch.Tensor] = []
self.ssm_states: list[torch.Tensor] = []
device = torch.device(device) if device is not None else None
for _ in range(config.num_hidden_layers):
conv_state: torch.Tensor = torch.zeros(
self.max_batch_size,
self.intermediate_size,
self.conv_kernel_size,
device=device,
dtype=self._dtype,
)
ssm_state: torch.Tensor = torch.zeros(
self.max_batch_size,
self.intermediate_size,
self.ssm_state_size,
device=device,
dtype=self._dtype,
)
torch._dynamo.mark_static_address(conv_state)
torch._dynamo.mark_static_address(ssm_state)
self.conv_states.append(conv_state)
self.ssm_states.append(ssm_state)
def update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
# This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used
# when the cache is initialized in the forward pass (e.g. FalconMamba)
if self.conv_states[layer_idx].device != new_conv_state.device:
self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device)
conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
conv_state = conv_state.roll(shifts=-1, dims=-1)
conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
self.conv_states[layer_idx].zero_()
self.conv_states[layer_idx] += conv_state
return self.conv_states[layer_idx]
def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
self.ssm_states[layer_idx].zero_()
self.ssm_states[layer_idx] += new_ssm_state.to(self.ssm_states[layer_idx].device)
return self.ssm_states[layer_idx]
def reset(self):
for layer_idx in range(len(self.conv_states)):
# In-place ops prevent breaking the static address
self.conv_states[layer_idx].zero_()
self.ssm_states[layer_idx].zero_()
def rms_forward(hidden_states, variance_epsilon=1e-6):
@@ -107,7 +213,7 @@ class FalconMambaMixer(nn.Module):
self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]
self.use_mambapy = config.use_mambapy
self.use_falcon_mambapy = config.use_falcon_mambapy
# projection of the input hidden states
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
@@ -126,6 +232,7 @@ class FalconMambaMixer(nn.Module):
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
self.use_bias = config.use_bias
self.warn_slow_implementation()
# Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here
self.register_buffer(
"b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False
@@ -135,8 +242,12 @@ class FalconMambaMixer(nn.Module):
)
self.rms_eps = config.mixer_rms_eps
def warn_slow_implementation(self):
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
if not is_fast_path_available:
if self.use_mambapy:
if self.use_falcon_mambapy:
if is_mambapy_available():
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
@@ -157,7 +268,7 @@ class FalconMambaMixer(nn.Module):
def cuda_kernels_forward(
self,
hidden_states: torch.Tensor,
cache_params: Optional[MambaCache] = None,
cache_params: Optional[FalconMambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
@@ -269,10 +380,10 @@ class FalconMambaMixer(nn.Module):
contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
return contextualized_states
def slow_forward(
self,
# fmt: off
def slow_forward(self,
input_states,
cache_params: Optional[MambaCache] = None,
cache_params: Optional[FalconMambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
@@ -344,7 +455,7 @@ class FalconMambaMixer(nn.Module):
deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
if self.use_mambapy and self.training and cache_params is None:
if self.use_falcon_mambapy and self.training and cache_params is None:
hs = pscan(
discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)
) # [batch, seq_len, intermediate_size, ssm_state_size]
@@ -371,21 +482,23 @@ class FalconMambaMixer(nn.Module):
# 4. Final linear projection
contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
return contextualized_states
# fmt: on
# Copied from transformers.models.mamba.modeling_mamba.MambaMixer.forward
def forward(
self,
hidden_states,
cache_params: Optional[MambaCache] = None,
cache_params: Optional[FalconMambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
# Copied from transformers.models.mamba.modeling_mamba.MambaRMSNorm with Mamba->FalconMamba
class FalconMambaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
@@ -395,17 +508,15 @@ class FalconMambaRMSNorm(nn.Module):
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def extra_repr(self):
return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"
# Ignore copy
def forward(self, hidden_states):
return self.weight.to(hidden_states.device) * rms_forward(
hidden_states, variance_epsilon=self.variance_epsilon
)
def extra_repr(self):
return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"
# Copied from transformers.models.mamba.modeling_mamba.MambaBlock with Mamba->FalconMamba,FalconMambaCache->MambaCache
class FalconMambaBlock(GradientCheckpointingLayer):
def __init__(self, config, layer_idx):
super().__init__()
@@ -418,7 +529,7 @@ class FalconMambaBlock(GradientCheckpointingLayer):
def forward(
self,
hidden_states,
cache_params: Optional[MambaCache] = None,
cache_params: Optional[FalconMambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
@@ -435,7 +546,6 @@ class FalconMambaBlock(GradientCheckpointingLayer):
@auto_docstring
# Copied from transformers.models.mamba.modeling_mamba.MambaPreTrainedModel with Mamba->FalconMamba
class FalconMambaPreTrainedModel(PreTrainedModel):
config: FalconMambaConfig
base_model_prefix = "backbone"
@@ -507,13 +617,12 @@ class FalconMambaPreTrainedModel(PreTrainedModel):
@dataclass
@auto_docstring(
custom_intro="""
Class for the FALCONMAMBA model outputs.
Class for the FALCON_MAMBA model outputs.
"""
)
# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->FALCONMAMBA,Mamba->FalconMamba,FalconMambaCache->MambaCache
class FalconMambaOutput(ModelOutput):
r"""
cache_params (`MambaCache`):
cache_params (`FalconMambaCache`):
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
avoid providing the old `input_ids`.
@@ -521,7 +630,7 @@ class FalconMambaOutput(ModelOutput):
"""
last_hidden_state: Optional[torch.FloatTensor] = None
cache_params: Optional[MambaCache] = None
cache_params: Optional[FalconMambaCache] = None
hidden_states: Optional[tuple[torch.FloatTensor]] = None
@@ -531,14 +640,13 @@ class FalconMambaOutput(ModelOutput):
Base class for causal language model (or autoregressive) outputs.
"""
)
# Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->FalconMamba,FalconMambaCache->MambaCache
class FalconMambaCausalLMOutput(ModelOutput):
r"""
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
cache_params (`MambaCache`):
cache_params (`FalconMambaCache`):
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
avoid providing the old `input_ids`.
@@ -547,7 +655,7 @@ class FalconMambaCausalLMOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
cache_params: Optional[MambaCache] = None
cache_params: Optional[FalconMambaCache] = None
hidden_states: Optional[tuple[torch.FloatTensor]] = None
@@ -577,7 +685,7 @@ class FalconMambaModel(FalconMambaPreTrainedModel):
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
cache_params: Optional[MambaCache] = None,
cache_params: Optional[FalconMambaCache] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@@ -585,7 +693,7 @@ class FalconMambaModel(FalconMambaPreTrainedModel):
attention_mask: Optional[torch.LongTensor] = None,
) -> Union[tuple, FalconMambaOutput]:
r"""
cache_params (`MambaCache`, *optional*):
cache_params (`FalconMambaCache`, *optional*):
If passed along, the model uses the previous state in all the blocks (which will give the output for the
`input_ids` provided as if the model add `state_input_ids + input_ids` as context).
use_cache (`bool`, *optional*):
@@ -608,7 +716,7 @@ class FalconMambaModel(FalconMambaPreTrainedModel):
if use_cache:
if cache_params is None:
cache_params = MambaCache(
cache_params = FalconMambaCache(
self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
)
cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
@@ -623,6 +731,7 @@ class FalconMambaModel(FalconMambaPreTrainedModel):
)
else:
cache_params = None
hidden_states = inputs_embeds
all_hidden_states = () if output_hidden_states else None
for mixer_block in self.layers:
@@ -653,11 +762,10 @@ class FalconMambaModel(FalconMambaPreTrainedModel):
@auto_docstring(
custom_intro="""
The FALCONMAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input
The FALCON_MAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input
embeddings).
"""
)
# Copied from transformers.models.mamba.modeling_mamba.MambaForCausalLM with MAMBA->FALCONMAMBA,Mamba->FalconMamba,mamba->falcon_mamba,FalconMambaCache->MambaCache
class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
@@ -704,38 +812,32 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
input_ids,
inputs_embeds=None,
use_cache=None,
cache_params: Optional[MambaCache] = None,
cache_params: Optional[FalconMambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
**kwargs,
):
# Overwritten -- uses `cache_params` as opposed to `past_key_values`
if use_cache:
# `cache_position` should have been initialized in `generate`
if cache_position is None:
raise ValueError(
"`cache_position` should not be None as it should have been initialized in "
"`model.generate`, you are responsible for passing in a valid `cache_position` if "
"you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
)
if cache_position[0] > 0:
input_ids = input_ids[:, -1].unsqueeze(-1)
if attention_mask is not None:
attention_mask = None
model_inputs = {"input_ids": input_ids.contiguous()}
if use_cache and cache_params is None:
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
# considering padding will be applied when input length is shorter, and truncation
# will be applied when it is longer, so it will be equivalent to always have it match
# the length of `cache_params.conv_states`, which is `config.conv_kernel`
cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device)
if inputs_embeds is not None:
model_inputs = {"inputs_embeds": inputs_embeds}
max_batch_size = inputs_embeds.size(0)
else:
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
# considering padding will be applied when input length is shorter, and truncation
# will be applied when it is longer, so it will be equivalent to always have it match
# the length of `cache_params.conv_states`, which is `config.conv_kernel`
cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)
max_batch_size = input_ids.size(0)
cache_params = FalconMambaCache(self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype)
if inputs_embeds is not None and cache_params is None:
if use_cache and cache_position[0] > 0:
model_inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1).contiguous()
attention_mask = None
if not use_cache and inputs_embeds is not None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids.contiguous()}
model_inputs.update(
{
@@ -753,7 +855,7 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_params: Optional[MambaCache] = None,
cache_params: Optional[FalconMambaCache] = None,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@@ -762,7 +864,7 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
**kwargs, # for now we need this for generation
) -> Union[tuple, FalconMambaCausalLMOutput]:
r"""
cache_params (`MambaCache`, *optional*):
cache_params (`FalconMambaCache`, *optional*):
If passed along, the model uses the previous state in all the blocks (which will give the output for the
`input_ids` provided as if the model add `state_input_ids + input_ids` as context).
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -811,4 +913,4 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
)
__all__ = ["FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel"]
__all__ = ["FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel", "FalconMambaCache"]

View File

@@ -0,0 +1,540 @@
# coding=utf-8
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch FALCONMAMBA model."""
from typing import Optional
import torch
import torch.utils.checkpoint
from torch import nn
from ...utils import auto_docstring, logging
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available
from ..mamba.configuration_mamba import MambaConfig
from ..mamba.modeling_mamba import (
MambaBlock,
MambaCache,
MambaCausalLMOutput,
MambaForCausalLM,
MambaMixer,
MambaModel,
MambaOutput,
MambaPreTrainedModel,
MambaRMSNorm,
)
logger = logging.get_logger(__name__)
if is_mambapy_available():
from mambapy.pscan import pscan
else:
pscan = None
if is_mamba_ssm_available():
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from ...kernels.falcon_mamba import mamba_inner_fn
else:
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_update, causal_conv1d_fn = None, None
class FalconMambaConfig(MambaConfig):
"""
This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the FALCON_MAMBA
[tiiuae/falcon-mamba-7b](https://huggingface.co/tiiuae/falcon-mamba-7b) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 50280):
Vocabulary size of the FALCON_MAMBA model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`FalconMambaModel`].
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the embeddings and hidden states.
state_size (`int`, *optional*, defaults to 16): shape of the state space latents.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the model.
layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
The epsilon to use in the layer normalization layers.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 0):
The id of the beginning of sentence token in the vocabulary.
eos_token_id (`int`, *optional*, defaults to 0):
The id of the end of sentence token in the vocabulary.
expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
use_bias (`bool`, *optional*, defaults to `False`):
Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
use_conv_bias (`bool`, *optional*, defaults to `True`):
Whether or not to use bias in the convolution layer of the mixer block.
hidden_act (`str`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
initializer_range (`float`, *optional*, defaults to 0.1):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
residual_in_fp32 (`bool`, *optional*, defaults to `True`):
Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model
time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
time_step_scale (`float`, *optional*, defaults to 1.0):
Scale used used to scale `dt_proj.bias`.
time_step_min (`float`, *optional*, defaults to 0.001):
Minimum `time_step` used to bound `dt_proj.bias`.
time_step_max (`float`, *optional*, defaults to 0.1):
Maximum `time_step` used to bound `dt_proj.bias`.
time_step_init_scheme (`float`, *optional*, defaults to `"random"`):
Init scheme used for `dt_proj.weight`. Should be one of `["random","uniform"]`
time_step_floor (`float`, *optional*, defaults to 0.0001):
Minimum clamping value of the `dt_proj.bias` layer initialization.
rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
Whether or not to rescale `out_proj` weights when initializing.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the cache should be used.
use_falcon_mambapy (`bool`, *optional*, defaults to `False`):
This argument corresponds to `use_mambapy` in MambaConfig.
Determines the fallback strategy during training if the CUDA-based official implementation of Mamba is not available. If `True`, the mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
mixer_rms_eps (`float`, *optional*, defaults to 1e-06):
The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states.
Example:
```python
>>> from transformers import FalconMambaConfig, FalconMambaModel
>>> # Initializing a FalconMamba configuration
>>> configuration = FalconMambaConfig()
>>> # Initializing a model (with random weights) from the configuration
>>> model = FalconMambaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
def __init__(
self,
vocab_size=50280,
hidden_size=768,
state_size=16,
num_hidden_layers=32,
layer_norm_epsilon=1e-5,
pad_token_id=0,
bos_token_id=0,
eos_token_id=0,
expand=2,
conv_kernel=4,
use_bias=False,
use_conv_bias=True,
hidden_act="silu",
initializer_range=0.1,
residual_in_fp32=True,
time_step_rank="auto",
time_step_scale=1.0,
time_step_min=0.001,
time_step_max=0.1,
time_step_init_scheme="random",
time_step_floor=1e-4,
rescale_prenorm_residual=False,
use_cache=True,
use_falcon_mambapy=False,
mixer_rms_eps=1e-6,
**kwargs,
):
super().__init__(
vocab_size=vocab_size,
hidden_size=hidden_size,
state_size=state_size,
num_hidden_layers=num_hidden_layers,
layer_norm_epsilon=layer_norm_epsilon,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
expand=expand,
conv_kernel=conv_kernel,
use_bias=use_bias,
use_conv_bias=use_conv_bias,
hidden_act=hidden_act,
initializer_range=initializer_range,
residual_in_fp32=residual_in_fp32,
time_step_rank=time_step_rank,
time_step_scale=time_step_scale,
time_step_min=time_step_min,
time_step_max=time_step_max,
time_step_init_scheme=time_step_init_scheme,
time_step_floor=time_step_floor,
rescale_prenorm_residual=rescale_prenorm_residual,
use_cache=use_cache,
use_falcon_mambapy=use_falcon_mambapy,
**kwargs,
)
self.mixer_rms_eps = mixer_rms_eps
class FalconMambaCache(MambaCache):
pass
def rms_forward(hidden_states, variance_epsilon=1e-6):
"""
Calculates simple RMSNorm with no learnable weights. `MambaRMSNorm` will
leverage this in order to multiply the final result with the RMSNorm weight
Args:
hidden_states (`torch.Tensor`):
Hidden states to normalize
variance_epsilon (`float`):
The eps value to add in the square root scaling factor
"""
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
return hidden_states.to(input_dtype)
class FalconMambaMixer(MambaMixer):
def warn_slow_implementation(self):
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
if not is_fast_path_available:
if self.use_falcon_mambapy:
if is_mambapy_available():
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
" is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d"
)
else:
raise ImportError(
"use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py."
)
else:
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
" is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
)
def __init__(self, config: FalconMambaConfig, layer_idx: int):
super().__init__(config, layer_idx)
# Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here
self.register_buffer(
"b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False
)
self.register_buffer(
"dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False
)
self.rms_eps = config.mixer_rms_eps
def cuda_kernels_forward(
self,
hidden_states: torch.Tensor,
cache_params: Optional[FalconMambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states).transpose(1, 2)
if self.training and cache_params is None: # Doesn't support outputting the states -> used for training
contextualized_states = mamba_inner_fn(
projected_states,
self.conv1d.weight,
self.conv1d.bias if self.use_conv_bias else None,
self.x_proj.weight,
self.dt_proj.weight,
self.out_proj.weight,
self.out_proj.bias.float() if self.use_bias else None,
-torch.exp(self.A_log.float()),
None, # input-dependent B
None, # input-dependent C
self.D.float(),
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
b_rms_weight=self.b_c_rms,
c_rms_weight=self.b_c_rms,
dt_rms_weight=self.dt_rms,
b_c_dt_rms_eps=self.rms_eps,
)
else:
hidden_states, gate = projected_states.chunk(2, dim=1)
if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
if cache_params is not None and cache_position[0] > 0:
hidden_states = causal_conv1d_update(
hidden_states.squeeze(-1),
cache_params.conv_states[self.layer_idx],
conv_weights,
self.conv1d.bias,
self.activation,
)
hidden_states = hidden_states.unsqueeze(-1)
else:
if cache_params is not None:
conv_states = nn.functional.pad(
hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
)
cache_params.update_conv_state(self.layer_idx, conv_states, cache_position)
hidden_states = causal_conv1d_fn(
hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
)
if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
time_step, B, C = torch.split(
ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
)
B = rms_forward(B, variance_epsilon=self.rms_eps)
C = rms_forward(C, variance_epsilon=self.rms_eps)
time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
# In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
# at the price of a small overhead.
if hasattr(self.config, "_pre_quantization_dtype"):
discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2)
else:
discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
A = -torch.exp(self.A_log.float())
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
if cache_params is not None and cache_position[0] > 0:
scan_outputs = selective_state_update(
cache_params.ssm_states[self.layer_idx],
hidden_states[..., 0],
discrete_time_step[..., 0],
A,
B[:, 0],
C[:, 0],
self.D,
gate[..., 0],
time_proj_bias,
dt_softplus=True,
).unsqueeze(-1)
else:
scan_outputs, ssm_state = selective_scan_fn(
hidden_states,
discrete_time_step,
A,
B.transpose(1, 2),
C.transpose(1, 2),
self.D.float(),
gate,
time_proj_bias,
delta_softplus=True,
return_last_state=True,
)
if ssm_state is not None and cache_params is not None:
cache_params.update_ssm_state(self.layer_idx, ssm_state)
# 4. Final linear projection
contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
return contextualized_states
def slow_forward(
self,
input_states,
cache_params: Optional[FalconMambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# 1. Gated MLP's linear projection
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
hidden_states, gate = projected_states.chunk(2, dim=1)
if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)
# 2. Convolution sequence transformation
if cache_params is not None:
ssm_state = cache_params.ssm_states[self.layer_idx].clone()
ssm_state = ssm_state.to(hidden_states.device)
# use `cache_position.shape[0]` to check whether we are in prefill
# stage, it's equivalent to check `cache_position[0] == 0`, which
# breaks dynamo fullgraph constraints
if cache_position is not None and cache_position.shape[0] == self.conv_kernel_size:
conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
cache_params.update_conv_state(self.layer_idx, conv_state, cache_position)
hidden_states = self.act(
self.conv1d(hidden_states)[..., :seq_len]
) # [batch, intermediate_size, seq_len]
else:
conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
conv_state = conv_state.to(self.conv1d.weight.device)
hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
if self.use_conv_bias:
hidden_states += self.conv1d.bias
hidden_states = (
self.act(hidden_states).to(dtype).unsqueeze(-1)
) # [batch, intermediate_size, 1] : decoding
else:
ssm_state = torch.zeros(
(batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype
)
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)
# 3. State Space Model sequence transformation
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
time_step, B, C = torch.split(
ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
)
B = rms_forward(B, variance_epsilon=self.rms_eps)
C = rms_forward(C, variance_epsilon=self.rms_eps)
time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]
discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(
1, 2
) # [batch, intermediate_size, seq_len]
# 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size]
discrete_A = torch.exp(
A[None, :, None, :] * discrete_time_step[:, :, :, None]
) # [batch, intermediate_size, seq_len, ssm_state_size]
discrete_B = (
discrete_time_step[:, :, :, None] * B[:, None, :, :].float()
) # [batch, intermediate_size, seq_len, ssm_state_size]
deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
if self.use_falcon_mambapy and self.training and cache_params is None:
hs = pscan(
discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)
) # [batch, seq_len, intermediate_size, ssm_state_size]
scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len]
scan_output = scan_output + hidden_states * self.D[None, :, None]
scan_output = scan_output * self.act(gate)
else:
scan_outputs = []
for i in range(seq_len):
ssm_state = (
discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
) # [batch, intermediate_size, ssm_state]
scan_output = torch.matmul(
ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)
) # [batch, intermediate_size, 1]
scan_outputs.append(scan_output[:, :, 0])
scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len]
scan_output = scan_output + (hidden_states * self.D[None, :, None])
scan_output = scan_output * self.act(gate)
if cache_params is not None:
cache_params.update_ssm_state(self.layer_idx, ssm_state)
# 4. Final linear projection
contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
return contextualized_states
def forward(
self,
hidden_states,
cache_params: Optional[FalconMambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
class FalconMambaRMSNorm(MambaRMSNorm):
def forward(self, hidden_states):
return self.weight.to(hidden_states.device) * rms_forward(
hidden_states, variance_epsilon=self.variance_epsilon
)
class FalconMambaBlock(MambaBlock):
pass
@auto_docstring
class FalconMambaPreTrainedModel(MambaPreTrainedModel):
pass
class FalconMambaOutput(MambaOutput):
pass
class FalconMambaCausalLMOutput(MambaCausalLMOutput):
pass
class FalconMambaModel(MambaModel, FalconMambaPreTrainedModel):
def __init__(self, config):
FalconMambaPreTrainedModel.__init__(config)
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList(
[FalconMambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]
)
self.gradient_checkpointing = False
self.norm_f = FalconMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
# Initialize weights and apply final processing
self.post_init()
def load_hook(self, state_dict, prefix, *args):
raise AttributeError("Not needed for FalconMamba")
class FalconMambaForCausalLM(MambaForCausalLM):
pass
__all__ = [
"FalconMambaForCausalLM",
"FalconMambaModel",
"FalconMambaPreTrainedModel",
"FalconMambaCache",
"FalconMambaConfig",
]

View File

@@ -24,7 +24,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import MambaCache
from ...configuration_utils import PretrainedConfig
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import PreTrainedModel
@@ -55,9 +55,106 @@ if is_causal_conv1d_available():
else:
causal_conv1d_update, causal_conv1d_fn = None, None
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
class MambaCache:
"""
Cache for mamba model which does not have attention mechanism and key value states.
Arguments:
config (`PretrainedConfig):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used.
dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
The default `dtype` to use when initializing the layer.
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. Should be the same as the layer.
Example:
```python
>>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache
>>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
>>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values
MambaCache()
```
"""
is_compileable = True
# TODO (joao): add layer_device_map arg and update code in `generate` accordingly
def __init__(
self,
config: PretrainedConfig,
max_batch_size: int,
dtype: torch.dtype = torch.float16,
device: Union[torch.device, str, None] = None,
):
self.max_batch_size = max_batch_size
self._dtype = dtype
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.conv_states: list[torch.Tensor] = []
self.ssm_states: list[torch.Tensor] = []
device = torch.device(device) if device is not None else None
for _ in range(config.num_hidden_layers):
conv_state: torch.Tensor = torch.zeros(
self.max_batch_size,
self.intermediate_size,
self.conv_kernel_size,
device=device,
dtype=self._dtype,
)
ssm_state: torch.Tensor = torch.zeros(
self.max_batch_size,
self.intermediate_size,
self.ssm_state_size,
device=device,
dtype=self._dtype,
)
torch._dynamo.mark_static_address(conv_state)
torch._dynamo.mark_static_address(ssm_state)
self.conv_states.append(conv_state)
self.ssm_states.append(ssm_state)
def update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
# This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used
# when the cache is initialized in the forward pass (e.g. Mamba)
if self.conv_states[layer_idx].device != new_conv_state.device:
self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device)
conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
conv_state = conv_state.roll(shifts=-1, dims=-1)
conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
self.conv_states[layer_idx].zero_()
self.conv_states[layer_idx] += conv_state
return self.conv_states[layer_idx]
def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
self.ssm_states[layer_idx].zero_()
self.ssm_states[layer_idx] += new_ssm_state.to(self.ssm_states[layer_idx].device)
return self.ssm_states[layer_idx]
def reset(self):
for layer_idx in range(len(self.conv_states)):
# In-place ops prevent breaking the static address
self.conv_states[layer_idx].zero_()
self.ssm_states[layer_idx].zero_()
class MambaMixer(nn.Module):
@@ -109,6 +206,12 @@ class MambaMixer(nn.Module):
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
self.use_bias = config.use_bias
self.warn_slow_implementation()
def warn_slow_implementation(self):
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
if not is_fast_path_available:
if self.use_mambapy:
if is_mambapy_available():
@@ -319,6 +422,9 @@ class MambaMixer(nn.Module):
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
@@ -650,32 +756,26 @@ class MambaForCausalLM(MambaPreTrainedModel, GenerationMixin):
**kwargs,
):
# Overwritten -- uses `cache_params` as opposed to `past_key_values`
if use_cache:
# `cache_position` should have been initialized in `generate`
if cache_position is None:
raise ValueError(
"`cache_position` should not be None as it should have been initialized in "
"`model.generate`, you are responsible for passing in a valid `cache_position` if "
"you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
)
if cache_position[0] > 0:
input_ids = input_ids[:, -1].unsqueeze(-1)
if attention_mask is not None:
attention_mask = None
model_inputs = {"input_ids": input_ids.contiguous()}
if use_cache and cache_params is None:
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
# considering padding will be applied when input length is shorter, and truncation
# will be applied when it is longer, so it will be equivalent to always have it match
# the length of `cache_params.conv_states`, which is `config.conv_kernel`
cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device)
if inputs_embeds is not None:
model_inputs = {"inputs_embeds": inputs_embeds}
max_batch_size = inputs_embeds.size(0)
else:
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
# considering padding will be applied when input length is shorter, and truncation
# will be applied when it is longer, so it will be equivalent to always have it match
# the length of `cache_params.conv_states`, which is `config.conv_kernel`
cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)
max_batch_size = input_ids.size(0)
cache_params = MambaCache(self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype)
if inputs_embeds is not None and cache_params is None:
if use_cache and cache_position[0] > 0:
model_inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1).contiguous()
attention_mask = None
if not use_cache and inputs_embeds is not None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids.contiguous()}
model_inputs.update(
{
@@ -751,4 +851,4 @@ class MambaForCausalLM(MambaPreTrainedModel, GenerationMixin):
)
__all__ = ["MambaForCausalLM", "MambaModel", "MambaPreTrainedModel"]
__all__ = ["MambaForCausalLM", "MambaModel", "MambaPreTrainedModel", "MambaCache"]

View File

@@ -970,38 +970,33 @@ class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin):
**kwargs,
):
# Overwritten -- uses `cache_params` as opposed to `past_key_values`
if use_cache:
# `cache_position` should have been initialized in `generate`
if cache_position is None:
raise ValueError(
"`cache_position` should not be None as it should have been initialized in "
"`model.generate`, you are responsible for passing in a valid `cache_position` if "
"you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
)
if cache_position[0] > 0:
input_ids = input_ids[:, -1][..., None]
if attention_mask is not None:
attention_mask = None
model_inputs = {"input_ids": input_ids.contiguous()}
if use_cache and cache_params is None:
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
# considering padding will be applied when input length is shorter, and truncation
# will be applied when it is longer, so it will be equivalent to always have it match
# the length of `cache_params.conv_states`, which is `config.conv_kernel`
cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device)
if inputs_embeds is not None:
model_inputs = {"inputs_embeds": inputs_embeds}
max_batch_size = inputs_embeds.size(0)
else:
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
# considering padding will be applied when input length is shorter, and truncation
# will be applied when it is longer, so it will be equivalent to always have it match
# the length of `cache_params.conv_states`, which is `config.conv_kernel`
cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)
max_batch_size = input_ids.size(0)
cache_params = Mamba2Cache(self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype)
if inputs_embeds is not None and cache_params is None:
if use_cache and cache_position[0] > 0:
model_inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1).contiguous()
attention_mask = None
if not use_cache and inputs_embeds is not None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"attention_mask": attention_mask,
"cache_params": cache_params,
"use_cache": use_cache,
"cache_position": cache_position,
"attention_mask": attention_mask,
}
)
return model_inputs

View File

@@ -44,13 +44,6 @@ class HybridCache(metaclass=DummyObject):
requires_backends(self, ["torch"])
class MambaCache(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class OffloadedCache(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -26,7 +26,6 @@ from transformers.testing_utils import (
require_torch_accelerator,
require_torch_large_accelerator,
require_torch_multi_accelerator,
require_torch_multi_gpu,
slow,
torch_device,
)
@@ -41,10 +40,10 @@ if is_torch_available():
import torch
from transformers import (
FalconMambaCache,
FalconMambaForCausalLM,
FalconMambaModel,
)
from transformers.cache_utils import MambaCache
# Copied from transformers.tests.models.mamba.MambaModelTester with Mamba->FalconMamba,mamba->falcon_mamba
@@ -312,31 +311,6 @@ class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
def test_config(self):
self.config_tester.run_common_tests()
@require_torch_multi_gpu
def test_multi_gpu_data_parallel_forward(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# some params shouldn't be scattered by nn.DataParallel
# so just remove them if they are present.
blacklist_non_batched_params = ["cache_params"]
for k in blacklist_non_batched_params:
inputs_dict.pop(k, None)
# move input tensors to cuda:O
for k, v in inputs_dict.items():
if torch.is_tensor(v):
inputs_dict[k] = v.to(0)
for model_class in self.all_model_classes:
model = model_class(config=config)
model.to(0)
model.eval()
# Wrap model in nn.DataParallel
model = torch.nn.DataParallel(model)
with torch.no_grad():
_ = model(**self._prepare_for_class(inputs_dict, model_class))
def test_falcon_mamba_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_falcon_mamba_model(*config_and_inputs)
@@ -411,7 +385,7 @@ class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, MambaCache): # MODIFIED PART START
if isinstance(tuple_object, FalconMambaCache): # MODIFIED PART START
recursive_check(tuple_object.conv_states, dict_object.conv_states)
recursive_check(tuple_object.ssm_states, dict_object.ssm_states)
elif isinstance(tuple_object, (list, tuple)): # MODIFIED PART END
@@ -458,6 +432,10 @@ class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
@unittest.skip("Mamba models do not support DDP.")
def test_multi_gpu_data_parallel_forward(self):
pass
@require_torch
@require_torch_accelerator
@@ -497,7 +475,9 @@ class FalconMambaIntegrationTests(unittest.TestCase):
@require_bitsandbytes
def test_generation_4bit(self):
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(self.model_id, quantization_config=quantization_config)
model = AutoModelForCausalLM.from_pretrained(self.model_id, quantization_config=quantization_config).to(
torch_device
)
inputs = self.tokenizer(self.text, return_tensors="pt").to(torch_device)
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
@@ -513,6 +493,7 @@ class FalconMambaIntegrationTests(unittest.TestCase):
inputs = self.tokenizer(self.text, return_tensors="pt").to(torch_device)
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(self.tokenizer.batch_decode(out, skip_special_tokens=False)[0])
self.assertEqual(
self.tokenizer.batch_decode(out, skip_special_tokens=False)[0],
@@ -543,7 +524,7 @@ class FalconMambaIntegrationTests(unittest.TestCase):
inputs = tok(texts, return_tensors="pt", padding=True, return_token_type_ids=False).to(torch_device)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=0, torch_dtype=torch.float16)
out = model.generate(**inputs, max_new_tokens=20)
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
out = tok.batch_decode(out, skip_special_tokens=True)
self.assertListEqual(out, EXPECTED_OUTPUT)
@@ -553,7 +534,7 @@ class FalconMambaIntegrationTests(unittest.TestCase):
inputs_embeds = model.get_input_embeddings()(inputs.pop("input_ids"))
inputs["inputs_embeds"] = inputs_embeds
out = model.generate(**inputs, max_new_tokens=20)
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
out = tok.batch_decode(out, skip_special_tokens=True)
EXPECTED_OUTPUTS = Expectations(

View File

@@ -20,7 +20,7 @@ from unittest.util import safe_repr
from parameterized import parameterized
from transformers import AutoTokenizer, MambaConfig, is_torch_available
from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
from transformers.testing_utils import require_torch, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -32,10 +32,10 @@ if is_torch_available():
import torch
from transformers import (
MambaCache,
MambaForCausalLM,
MambaModel,
)
from transformers.models.mamba.modeling_mamba import MambaCache
class MambaModelTester:
@@ -279,31 +279,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def test_config(self):
self.config_tester.run_common_tests()
@require_torch_multi_gpu
def test_multi_gpu_data_parallel_forward(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# some params shouldn't be scattered by nn.DataParallel
# so just remove them if they are present.
blacklist_non_batched_params = ["cache_params"]
for k in blacklist_non_batched_params:
inputs_dict.pop(k, None)
# move input tensors to cuda:O
for k, v in inputs_dict.items():
if torch.is_tensor(v):
inputs_dict[k] = v.to(0)
for model_class in self.all_model_classes:
model = model_class(config=config)
model.to(0)
model.eval()
# Wrap model in nn.DataParallel
model = torch.nn.DataParallel(model)
with torch.no_grad():
_ = model(**self._prepare_for_class(inputs_dict, model_class))
def test_mamba_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_mamba_model(*config_and_inputs)
@@ -452,6 +427,10 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.hidden_size),
)
@unittest.skip("Mamba models do not support DDP.")
def test_multi_gpu_data_parallel_forward(self):
pass
@require_torch
class MambaIntegrationTests(unittest.TestCase):
@@ -547,11 +526,11 @@ class MambaIntegrationTests(unittest.TestCase):
torch_device
)
output = model.generate(input_ids, max_new_tokens=20, cache_implementation="mamba")
output = model.generate(input_ids, max_new_tokens=20)
output_sentence = self.tokenizer.decode(output[0].tolist())
self.assertEqual(output_sentence, expected_output)
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
output = model.generate(input_ids, max_new_tokens=20, cache_implementation="mamba")
output = model.generate(input_ids, max_new_tokens=20)
output_sentence = self.tokenizer.decode(output[0].tolist())
self.assertEqual(output_sentence, expected_output)

View File

@@ -62,8 +62,6 @@ if is_torch_available():
TEST_CACHE_IMPLEMENTATIONS = [
cache_name
for cache_name in ALL_CACHE_IMPLEMENTATIONS
# TODO (joao): Mamba is not compatible with most models, remove from `ALL_CACHE_IMPLEMENTATIONS`?
if cache_name != "mamba"
# TODO (joao): offloaded_hybrid == offloaded_hybrid_chunked, deprecate one of them
if cache_name != "offloaded_hybrid"
]

View File

@@ -141,7 +141,12 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
return updated_node
def leave_ImportFrom(self, original_node, updated_node):
"""The imports from other file types (configuration, processing etc) should use original model name."""
"""
The imports from other file types (configuration, processing etc) should use original model name.
Also, no replaces on absolute imports (e.g. `from mamba_ssm import ...`)
"""
if len(original_node.relative) == 0: # no replaces on absolute imports
return original_node
if self.original_new_model_name != self.new_name and m.matches(updated_node.module, m.Name()):
patterns = "|".join(ALL_FILE_TYPES)
regex = rf"({patterns})_{self.new_name}"