Refactor MambaCache to modeling_mamba.py (#38086)
* Refactor MambaCache to modeling_mamba.py (parity with Zamba) * ruff * fix dummies * update * update * remove mamba ref in cache tests * remove cache_implementation from tests * update * ruff * ruff * sneaky regression * model consistency * fix test_multi_gpu_data_parallel_forward * fix falcon slow tests * ruff * ruff * add sample false * try to fix slow tests * Revert "fix test_multi_gpu_data_parallel_forward" This reverts commit 66b7162c7c5c5ce8a73ccf48cffc8a96343ebb33. * fix tests on nvidia t4, remove dataparallel tests from mamba * ruff * remove DDP tests from mamba and falcon_mamba * add explicit error for MambaCache * mamba2 also needs to init cache in prepare_inputs_for_generation * ruff * ruff * move MambaCache to its own file * ruff * unprotected import fix * another attempt to fix unprotected imports * Revert "another attempt to fix unprotected imports" This reverts commit 2338354fcab630de5899321f5daced5fb312c2a2. * fixing unprotected import, attempt 3 * Update src/transformers/cache_utils.py * ruff's fault * fix arthur review * modular falcon mamba * found a hack * fix config docs * fix docs * add export info * merge modular falcon branch * oopsie * fix fast path failing * new approach * oopsie * fix types * Revert new pragma in modular This reverts commit 80b1cf160ee251536f07c40b8a0857d499e70db6. * trying another modular workaround * review & fix ci * oopsie * clear prepare_inputs on mamba/mamba2/falcon_mamba
This commit is contained in:
committed by
GitHub
parent
a419a40234
commit
1aa7256f01
@@ -110,6 +110,13 @@ outputs = model.generate(**inputs, max_new_tokens=100)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
## FalconMambaCache
|
||||
|
||||
[[autodoc]] FalconMambaCache
|
||||
- update_conv_state
|
||||
- update_ssm_state
|
||||
- reset
|
||||
|
||||
## FalconMambaConfig
|
||||
|
||||
[[autodoc]] FalconMambaConfig
|
||||
|
||||
@@ -116,6 +116,13 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## MambaCache
|
||||
|
||||
[[autodoc]] MambaCache
|
||||
- update_conv_state
|
||||
- update_ssm_state
|
||||
- reset
|
||||
|
||||
## MambaConfig
|
||||
|
||||
[[autodoc]] MambaConfig
|
||||
|
||||
@@ -371,7 +371,6 @@ else:
|
||||
"EncoderDecoderCache",
|
||||
"HQQQuantizedCache",
|
||||
"HybridCache",
|
||||
"MambaCache",
|
||||
"OffloadedCache",
|
||||
"OffloadedStaticCache",
|
||||
"QuantizedCache",
|
||||
|
||||
@@ -2,6 +2,7 @@ import copy
|
||||
import importlib.metadata
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
@@ -18,6 +19,7 @@ from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_great
|
||||
if is_hqq_available():
|
||||
from hqq.core.quantize import Quantizer as HQQQuantizer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -2091,106 +2093,6 @@ class OffloadedHybridCache(HybridChunkedCache):
|
||||
self.device_value_cache[self.active_device_layer].fill_(0.0)
|
||||
|
||||
|
||||
class MambaCache:
|
||||
"""
|
||||
Cache for mamba model which does not have attention mechanism and key value states.
|
||||
|
||||
Arguments:
|
||||
config (`PretrainedConfig`):
|
||||
The configuration file defining the shape-related attributes required to initialize the static cache.
|
||||
max_batch_size (`int`):
|
||||
The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used.
|
||||
dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
|
||||
The default `dtype` to use when initializing the layer.
|
||||
device (`torch.device` or `str`, *optional*):
|
||||
The device on which the cache should be initialized. Should be the same as the layer.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache
|
||||
|
||||
>>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
|
||||
|
||||
>>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
|
||||
|
||||
>>> # Prepare a cache class and pass it to model's forward
|
||||
>>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
|
||||
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
||||
>>> outputs.past_key_values
|
||||
MambaCache()
|
||||
```
|
||||
"""
|
||||
|
||||
is_compileable = True
|
||||
|
||||
# TODO (joao): add layer_device_map arg and update code in `generate` accordingly
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
max_batch_size: int,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
device: Union[torch.device, str, None] = None,
|
||||
):
|
||||
self.max_batch_size = max_batch_size
|
||||
self._dtype = dtype
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.ssm_state_size = config.state_size
|
||||
self.conv_kernel_size = config.conv_kernel
|
||||
|
||||
self.conv_states: list[torch.Tensor] = []
|
||||
self.ssm_states: list[torch.Tensor] = []
|
||||
device = torch.device(device) if device is not None else None
|
||||
for _ in range(config.num_hidden_layers):
|
||||
conv_state: torch.Tensor = torch.zeros(
|
||||
self.max_batch_size,
|
||||
self.intermediate_size,
|
||||
self.conv_kernel_size,
|
||||
device=device,
|
||||
dtype=self._dtype,
|
||||
)
|
||||
ssm_state: torch.Tensor = torch.zeros(
|
||||
self.max_batch_size,
|
||||
self.intermediate_size,
|
||||
self.ssm_state_size,
|
||||
device=device,
|
||||
dtype=self._dtype,
|
||||
)
|
||||
|
||||
torch._dynamo.mark_static_address(conv_state)
|
||||
torch._dynamo.mark_static_address(ssm_state)
|
||||
self.conv_states.append(conv_state)
|
||||
self.ssm_states.append(ssm_state)
|
||||
|
||||
def update_conv_state(
|
||||
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
|
||||
) -> torch.Tensor:
|
||||
# This `if` block is only reached in multi-GPU setups when `layer_device_map` is not passed. It is used
|
||||
# when the cache is initialized in the forward pass (e.g. Mamba)
|
||||
if self.conv_states[layer_idx].device != new_conv_state.device:
|
||||
self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device)
|
||||
|
||||
conv_state = self.conv_states[layer_idx]
|
||||
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
|
||||
|
||||
conv_state = conv_state.roll(shifts=-1, dims=-1)
|
||||
conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
|
||||
self.conv_states[layer_idx].zero_()
|
||||
self.conv_states[layer_idx] += conv_state
|
||||
return self.conv_states[layer_idx]
|
||||
|
||||
def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
|
||||
self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
|
||||
return self.ssm_states[layer_idx]
|
||||
|
||||
def reset(self):
|
||||
for layer_idx in range(len(self.conv_states)):
|
||||
# In-place ops prevent breaking the static address
|
||||
self.conv_states[layer_idx].zero_()
|
||||
self.ssm_states[layer_idx].zero_()
|
||||
|
||||
|
||||
class OffloadedStaticCache(StaticCache):
|
||||
"""
|
||||
Static cache class to be used with `torch.compile(model)` that offloads to the CPU or
|
||||
@@ -2461,3 +2363,122 @@ class OffloadedStaticCache(StaticCache):
|
||||
|
||||
self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True)
|
||||
self._device_value_cache[layer_idx & 1].copy_(self.value_cache[layer_idx], non_blocking=True)
|
||||
|
||||
|
||||
# TODO (manuel, joao): remove this class, it is here only for backwards compatibility
|
||||
# PEP 562: Lazy loading for deprecated location of MambaCache
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name == "MambaCache":
|
||||
warnings.warn(
|
||||
(
|
||||
"Importing `MambaCache` from `transformers.cache_utils` is deprecated and will be removed "
|
||||
"in a future version. Please import it from `transformers` or `transformers.models.mamba.cache_mamba` instead."
|
||||
),
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
class MambaCache:
|
||||
"""
|
||||
Importing `MambaCache` from `transformers.cache_utils` is deprecated and will be removed
|
||||
in a future version. Please import it from `transformers` or `transformers.models.mamba.cache_mamba` instead.
|
||||
|
||||
Cache for mamba model which does not have attention mechanism and key value states.
|
||||
|
||||
Arguments:
|
||||
config (`PretrainedConfig`):
|
||||
The configuration file defining the shape-related attributes required to initialize the static cache.
|
||||
max_batch_size (`int`):
|
||||
The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used.
|
||||
dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
|
||||
The default `dtype` to use when initializing the layer.
|
||||
device (`torch.device` or `str`, *optional*):
|
||||
The device on which the cache should be initialized. Should be the same as the layer.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache
|
||||
|
||||
>>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
|
||||
|
||||
>>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
|
||||
|
||||
>>> # Prepare a cache class and pass it to model's forward
|
||||
>>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
|
||||
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
||||
>>> outputs.past_key_values
|
||||
MambaCache()
|
||||
```
|
||||
"""
|
||||
|
||||
is_compileable = True
|
||||
|
||||
# TODO (joao): add layer_device_map arg and update code in `generate` accordingly
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_batch_size: int,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
device: Union[torch.device, str, None] = None,
|
||||
):
|
||||
self.max_batch_size = max_batch_size
|
||||
self._dtype = dtype
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.ssm_state_size = config.state_size
|
||||
self.conv_kernel_size = config.conv_kernel
|
||||
|
||||
self.conv_states: list[torch.Tensor] = []
|
||||
self.ssm_states: list[torch.Tensor] = []
|
||||
device = torch.device(device) if device is not None else None
|
||||
for _ in range(config.num_hidden_layers):
|
||||
conv_state: torch.Tensor = torch.zeros(
|
||||
self.max_batch_size,
|
||||
self.intermediate_size,
|
||||
self.conv_kernel_size,
|
||||
device=device,
|
||||
dtype=self._dtype,
|
||||
)
|
||||
ssm_state: torch.Tensor = torch.zeros(
|
||||
self.max_batch_size,
|
||||
self.intermediate_size,
|
||||
self.ssm_state_size,
|
||||
device=device,
|
||||
dtype=self._dtype,
|
||||
)
|
||||
|
||||
torch._dynamo.mark_static_address(conv_state)
|
||||
torch._dynamo.mark_static_address(ssm_state)
|
||||
self.conv_states.append(conv_state)
|
||||
self.ssm_states.append(ssm_state)
|
||||
|
||||
def update_conv_state(
|
||||
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
|
||||
) -> torch.Tensor:
|
||||
# This `if` block is only reached in multi-GPU setups when `layer_device_map` is not passed. It is used
|
||||
# when the cache is initialized in the forward pass (e.g. Mamba)
|
||||
if self.conv_states[layer_idx].device != new_conv_state.device:
|
||||
self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device)
|
||||
|
||||
conv_state = self.conv_states[layer_idx]
|
||||
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
|
||||
|
||||
conv_state = conv_state.roll(shifts=-1, dims=-1)
|
||||
conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
|
||||
self.conv_states[layer_idx].zero_()
|
||||
self.conv_states[layer_idx] += conv_state
|
||||
return self.conv_states[layer_idx]
|
||||
|
||||
def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
|
||||
self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
|
||||
return self.ssm_states[layer_idx]
|
||||
|
||||
def reset(self):
|
||||
for layer_idx in range(len(self.conv_states)):
|
||||
# In-place ops prevent breaking the static address
|
||||
self.conv_states[layer_idx].zero_()
|
||||
self.ssm_states[layer_idx].zero_()
|
||||
|
||||
return MambaCache
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -54,7 +54,6 @@ if is_torch_available():
|
||||
HQQQuantizedCache,
|
||||
HybridCache,
|
||||
HybridChunkedCache,
|
||||
MambaCache,
|
||||
OffloadedHybridCache,
|
||||
OffloadedStaticCache,
|
||||
QuantizedCacheConfig,
|
||||
@@ -75,7 +74,6 @@ if is_torch_available():
|
||||
"hybrid_chunked": HybridChunkedCache,
|
||||
"offloaded_hybrid": OffloadedHybridCache,
|
||||
"offloaded_hybrid_chunked": OffloadedHybridCache,
|
||||
"mamba": MambaCache,
|
||||
}
|
||||
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
|
||||
ALL_CACHE_IMPLEMENTATIONS = (
|
||||
@@ -186,7 +184,6 @@ class GenerationConfig(PushToHubMixin):
|
||||
- `"offloaded_static"`: [`OffloadedStaticCache`]
|
||||
- `"sliding_window"`: [`SlidingWindowCache`]
|
||||
- `"hybrid"`: [`HybridCache`]
|
||||
- `"mamba"`: [`MambaCache`]
|
||||
- `"quantized"`: [`QuantizedCache`]
|
||||
|
||||
If none is specified, we will use the default cache for the model (which is often [`DynamicCache`]). See
|
||||
|
||||
@@ -1916,9 +1916,8 @@ class GenerationMixin(ContinuousMixin):
|
||||
or isinstance(
|
||||
cache_to_check, (HybridChunkedCache, OffloadedHybridCache)
|
||||
) # due to internal slicing, we always re-init
|
||||
or cache_to_check.max_cache_len < max_cache_len
|
||||
)
|
||||
if cache_implementation != "mamba":
|
||||
need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
|
||||
|
||||
if requires_cross_attention_cache and hasattr(self, "_cache"):
|
||||
need_new_cache = (
|
||||
@@ -1957,9 +1956,9 @@ class GenerationMixin(ContinuousMixin):
|
||||
def _supports_default_dynamic_cache(cls) -> bool:
|
||||
"""
|
||||
Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`.
|
||||
This adds exception for some models like `Jamba` model which uses its own `HybridMambaAttentionDynamicCache`
|
||||
This adds an exception for models such as `Mamba`, which use their own caches
|
||||
and do not need to initialize the Cache in advance in order to save memory (because no back and forth
|
||||
`to_legacy_cache` and `from_legacy_cache` will be performed for `HybridMambaAttentionDynamicCache`).
|
||||
`to_legacy_cache` and `from_legacy_cache` will be performed for mamba-based models).
|
||||
"""
|
||||
# NOTE: remove xlnet/reformer when the models are deprecated, non-standard model architecture/cache name
|
||||
return not cls._is_stateful and all(
|
||||
@@ -2016,7 +2015,7 @@ class GenerationMixin(ContinuousMixin):
|
||||
if generation_config.use_cache is False:
|
||||
return
|
||||
|
||||
# Quick escape route 3: model that only supports legacy caches = nothing to prepare
|
||||
# Quick escape route 3: model that only supports legacy caches or models that supply it in `prepare_inputs_for_generation` (mamba, zamba, ...)
|
||||
if not self._supports_default_dynamic_cache():
|
||||
if generation_config.cache_implementation is not None:
|
||||
warnings.warn(
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/falcon_mamba/modular_falcon_mamba.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_falcon_mamba.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -12,15 +18,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""FALCONMAMBA configuration"""
|
||||
|
||||
import math
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class FalconMambaConfig(PretrainedConfig):
|
||||
@@ -79,10 +80,13 @@ class FalconMambaConfig(PretrainedConfig):
|
||||
Whether or not to rescale `out_proj` weights when initializing.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the cache should be used.
|
||||
use_mambapy (`bool`, *optional*, defaults to `False`):
|
||||
Determines the fallback strategy during training if the CUDA-based official implementation of FalconMamba is not available. If `True`, the falcon_mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
|
||||
use_falcon_mambapy (`bool`, *optional*, defaults to `False`):
|
||||
This argument corresponds to `use_mambapy` in MambaConfig.
|
||||
Determines the fallback strategy during training if the CUDA-based official implementation of Mamba is not available. If `True`, the mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
|
||||
mixer_rms_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
@@ -125,10 +129,11 @@ class FalconMambaConfig(PretrainedConfig):
|
||||
time_step_floor=1e-4,
|
||||
rescale_prenorm_residual=False,
|
||||
use_cache=True,
|
||||
use_mambapy=False,
|
||||
use_falcon_mambapy=False,
|
||||
mixer_rms_eps=1e-6,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.state_size = state_size
|
||||
@@ -153,10 +158,8 @@ class FalconMambaConfig(PretrainedConfig):
|
||||
self.rescale_prenorm_residual = rescale_prenorm_residual
|
||||
self.residual_in_fp32 = residual_in_fp32
|
||||
self.use_cache = use_cache
|
||||
self.use_mambapy = use_mambapy
|
||||
self.use_falcon_mambapy = use_falcon_mambapy
|
||||
self.mixer_rms_eps = mixer_rms_eps
|
||||
|
||||
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
|
||||
|
||||
|
||||
__all__ = ["FalconMambaConfig"]
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/falcon_mamba/modular_falcon_mamba.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_falcon_mamba.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
|
||||
#
|
||||
@@ -12,29 +18,29 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch FALCONMAMBA model."""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import MambaCache
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import ModelOutput, auto_docstring, logging
|
||||
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available
|
||||
from ...utils.import_utils import (
|
||||
is_causal_conv1d_available,
|
||||
is_mamba_ssm_available,
|
||||
is_mambapy_available,
|
||||
)
|
||||
from .configuration_falcon_mamba import FalconMambaConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if is_mambapy_available():
|
||||
from mambapy.pscan import pscan
|
||||
else:
|
||||
@@ -53,9 +59,109 @@ if is_causal_conv1d_available():
|
||||
else:
|
||||
causal_conv1d_update, causal_conv1d_fn = None, None
|
||||
|
||||
is_fast_path_available = all(
|
||||
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class FalconMambaCache:
|
||||
"""
|
||||
Cache for falcon_mamba model which does not have attention mechanism and key value states.
|
||||
|
||||
Arguments:
|
||||
config (`PretrainedConfig`):
|
||||
The configuration file defining the shape-related attributes required to initialize the static cache.
|
||||
max_batch_size (`int`):
|
||||
The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used.
|
||||
dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
|
||||
The default `dtype` to use when initializing the layer.
|
||||
device (`torch.device` or `str`, *optional*):
|
||||
The device on which the cache should be initialized. Should be the same as the layer.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, FalconMambaForCausalLM, FalconMambaCache
|
||||
|
||||
>>> model = FalconMambaForCausalLM.from_pretrained("state-spaces/falcon_mamba-130m-hf")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/falcon_mamba-130m-hf")
|
||||
|
||||
>>> inputs = tokenizer(text="My name is FalconMamba", return_tensors="pt")
|
||||
|
||||
>>> # Prepare a cache class and pass it to model's forward
|
||||
>>> past_key_values = FalconMambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
|
||||
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
||||
>>> outputs.past_key_values
|
||||
FalconMambaCache()
|
||||
```
|
||||
"""
|
||||
|
||||
is_compileable = True
|
||||
|
||||
# TODO (joao): add layer_device_map arg and update code in `generate` accordingly
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
max_batch_size: int,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
device: Union[torch.device, str, None] = None,
|
||||
):
|
||||
self.max_batch_size = max_batch_size
|
||||
self._dtype = dtype
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.ssm_state_size = config.state_size
|
||||
self.conv_kernel_size = config.conv_kernel
|
||||
|
||||
self.conv_states: list[torch.Tensor] = []
|
||||
self.ssm_states: list[torch.Tensor] = []
|
||||
device = torch.device(device) if device is not None else None
|
||||
for _ in range(config.num_hidden_layers):
|
||||
conv_state: torch.Tensor = torch.zeros(
|
||||
self.max_batch_size,
|
||||
self.intermediate_size,
|
||||
self.conv_kernel_size,
|
||||
device=device,
|
||||
dtype=self._dtype,
|
||||
)
|
||||
ssm_state: torch.Tensor = torch.zeros(
|
||||
self.max_batch_size,
|
||||
self.intermediate_size,
|
||||
self.ssm_state_size,
|
||||
device=device,
|
||||
dtype=self._dtype,
|
||||
)
|
||||
|
||||
torch._dynamo.mark_static_address(conv_state)
|
||||
torch._dynamo.mark_static_address(ssm_state)
|
||||
self.conv_states.append(conv_state)
|
||||
self.ssm_states.append(ssm_state)
|
||||
|
||||
def update_conv_state(
|
||||
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
|
||||
) -> torch.Tensor:
|
||||
# This `if` block is only reached in multi-GPU setups when `layer_device_map` is not passed. It is used
|
||||
# when the cache is initialized in the forward pass (e.g. FalconMamba)
|
||||
if self.conv_states[layer_idx].device != new_conv_state.device:
|
||||
self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device)
|
||||
|
||||
conv_state = self.conv_states[layer_idx]
|
||||
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
|
||||
|
||||
conv_state = conv_state.roll(shifts=-1, dims=-1)
|
||||
conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
|
||||
self.conv_states[layer_idx].zero_()
|
||||
self.conv_states[layer_idx] += conv_state
|
||||
return self.conv_states[layer_idx]
|
||||
|
||||
def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
|
||||
self.ssm_states[layer_idx].zero_()
|
||||
self.ssm_states[layer_idx] += new_ssm_state.to(self.ssm_states[layer_idx].device)
|
||||
return self.ssm_states[layer_idx]
|
||||
|
||||
def reset(self):
|
||||
for layer_idx in range(len(self.conv_states)):
|
||||
# In-place ops prevent breaking the static address
|
||||
self.conv_states[layer_idx].zero_()
|
||||
self.ssm_states[layer_idx].zero_()
|
||||
|
||||
|
||||
def rms_forward(hidden_states, variance_epsilon=1e-6):
|
||||
@@ -107,7 +213,7 @@ class FalconMambaMixer(nn.Module):
|
||||
self.activation = config.hidden_act
|
||||
self.act = ACT2FN[config.hidden_act]
|
||||
|
||||
self.use_mambapy = config.use_mambapy
|
||||
self.use_falcon_mambapy = config.use_falcon_mambapy
|
||||
|
||||
# projection of the input hidden states
|
||||
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
|
||||
@@ -126,6 +232,7 @@ class FalconMambaMixer(nn.Module):
|
||||
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
|
||||
self.use_bias = config.use_bias
|
||||
|
||||
self.warn_slow_implementation()
|
||||
# Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here
|
||||
self.register_buffer(
|
||||
"b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False
|
||||
@@ -135,8 +242,12 @@ class FalconMambaMixer(nn.Module):
|
||||
)
|
||||
self.rms_eps = config.mixer_rms_eps
|
||||
|
||||
def warn_slow_implementation(self):
|
||||
is_fast_path_available = all(
|
||||
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
|
||||
)
|
||||
if not is_fast_path_available:
|
||||
if self.use_mambapy:
|
||||
if self.use_falcon_mambapy:
|
||||
if is_mambapy_available():
|
||||
logger.warning_once(
|
||||
"The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
|
||||
@@ -157,7 +268,7 @@ class FalconMambaMixer(nn.Module):
|
||||
def cuda_kernels_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cache_params: Optional[MambaCache] = None,
|
||||
cache_params: Optional[FalconMambaCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
@@ -269,10 +380,10 @@ class FalconMambaMixer(nn.Module):
|
||||
contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
|
||||
return contextualized_states
|
||||
|
||||
def slow_forward(
|
||||
self,
|
||||
# fmt: off
|
||||
def slow_forward(self,
|
||||
input_states,
|
||||
cache_params: Optional[MambaCache] = None,
|
||||
cache_params: Optional[FalconMambaCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
@@ -344,7 +455,7 @@ class FalconMambaMixer(nn.Module):
|
||||
deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
|
||||
|
||||
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
|
||||
if self.use_mambapy and self.training and cache_params is None:
|
||||
if self.use_falcon_mambapy and self.training and cache_params is None:
|
||||
hs = pscan(
|
||||
discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)
|
||||
) # [batch, seq_len, intermediate_size, ssm_state_size]
|
||||
@@ -371,21 +482,23 @@ class FalconMambaMixer(nn.Module):
|
||||
# 4. Final linear projection
|
||||
contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
|
||||
return contextualized_states
|
||||
# fmt: on
|
||||
|
||||
# Copied from transformers.models.mamba.modeling_mamba.MambaMixer.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cache_params: Optional[MambaCache] = None,
|
||||
cache_params: Optional[FalconMambaCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
is_fast_path_available = all(
|
||||
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
|
||||
)
|
||||
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
|
||||
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
|
||||
return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
|
||||
|
||||
|
||||
# Copied from transformers.models.mamba.modeling_mamba.MambaRMSNorm with Mamba->FalconMamba
|
||||
class FalconMambaRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
@@ -395,17 +508,15 @@ class FalconMambaRMSNorm(nn.Module):
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"
|
||||
|
||||
# Ignore copy
|
||||
def forward(self, hidden_states):
|
||||
return self.weight.to(hidden_states.device) * rms_forward(
|
||||
hidden_states, variance_epsilon=self.variance_epsilon
|
||||
)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
# Copied from transformers.models.mamba.modeling_mamba.MambaBlock with Mamba->FalconMamba,FalconMambaCache->MambaCache
|
||||
class FalconMambaBlock(GradientCheckpointingLayer):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
@@ -418,7 +529,7 @@ class FalconMambaBlock(GradientCheckpointingLayer):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cache_params: Optional[MambaCache] = None,
|
||||
cache_params: Optional[FalconMambaCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
@@ -435,7 +546,6 @@ class FalconMambaBlock(GradientCheckpointingLayer):
|
||||
|
||||
|
||||
@auto_docstring
|
||||
# Copied from transformers.models.mamba.modeling_mamba.MambaPreTrainedModel with Mamba->FalconMamba
|
||||
class FalconMambaPreTrainedModel(PreTrainedModel):
|
||||
config: FalconMambaConfig
|
||||
base_model_prefix = "backbone"
|
||||
@@ -507,13 +617,12 @@ class FalconMambaPreTrainedModel(PreTrainedModel):
|
||||
@dataclass
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Class for the FALCONMAMBA model outputs.
|
||||
Class for the FALCON_MAMBA model outputs.
|
||||
"""
|
||||
)
|
||||
# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->FALCONMAMBA,Mamba->FalconMamba,FalconMambaCache->MambaCache
|
||||
class FalconMambaOutput(ModelOutput):
|
||||
r"""
|
||||
cache_params (`MambaCache`):
|
||||
cache_params (`FalconMambaCache`):
|
||||
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
|
||||
avoid providing the old `input_ids`.
|
||||
|
||||
@@ -521,7 +630,7 @@ class FalconMambaOutput(ModelOutput):
|
||||
"""
|
||||
|
||||
last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
cache_params: Optional[MambaCache] = None
|
||||
cache_params: Optional[FalconMambaCache] = None
|
||||
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@@ -531,14 +640,13 @@ class FalconMambaOutput(ModelOutput):
|
||||
Base class for causal language model (or autoregressive) outputs.
|
||||
"""
|
||||
)
|
||||
# Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->FalconMamba,FalconMambaCache->MambaCache
|
||||
class FalconMambaCausalLMOutput(ModelOutput):
|
||||
r"""
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Language modeling loss (for next-token prediction).
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
cache_params (`MambaCache`):
|
||||
cache_params (`FalconMambaCache`):
|
||||
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
|
||||
avoid providing the old `input_ids`.
|
||||
|
||||
@@ -547,7 +655,7 @@ class FalconMambaCausalLMOutput(ModelOutput):
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: Optional[torch.FloatTensor] = None
|
||||
cache_params: Optional[MambaCache] = None
|
||||
cache_params: Optional[FalconMambaCache] = None
|
||||
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@@ -577,7 +685,7 @@ class FalconMambaModel(FalconMambaPreTrainedModel):
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||
cache_params: Optional[MambaCache] = None,
|
||||
cache_params: Optional[FalconMambaCache] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
@@ -585,7 +693,7 @@ class FalconMambaModel(FalconMambaPreTrainedModel):
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
) -> Union[tuple, FalconMambaOutput]:
|
||||
r"""
|
||||
cache_params (`MambaCache`, *optional*):
|
||||
cache_params (`FalconMambaCache`, *optional*):
|
||||
If passed along, the model uses the previous state in all the blocks (which will give the output for the
|
||||
`input_ids` provided as if the model add `state_input_ids + input_ids` as context).
|
||||
use_cache (`bool`, *optional*):
|
||||
@@ -608,7 +716,7 @@ class FalconMambaModel(FalconMambaPreTrainedModel):
|
||||
|
||||
if use_cache:
|
||||
if cache_params is None:
|
||||
cache_params = MambaCache(
|
||||
cache_params = FalconMambaCache(
|
||||
self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
|
||||
)
|
||||
cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
|
||||
@@ -623,6 +731,7 @@ class FalconMambaModel(FalconMambaPreTrainedModel):
|
||||
)
|
||||
else:
|
||||
cache_params = None
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for mixer_block in self.layers:
|
||||
@@ -653,11 +762,10 @@ class FalconMambaModel(FalconMambaPreTrainedModel):
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The FALCONMAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
||||
The FALCON_MAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
||||
embeddings).
|
||||
"""
|
||||
)
|
||||
# Copied from transformers.models.mamba.modeling_mamba.MambaForCausalLM with MAMBA->FALCONMAMBA,Mamba->FalconMamba,mamba->falcon_mamba,FalconMambaCache->MambaCache
|
||||
class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
@@ -704,38 +812,32 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
|
||||
input_ids,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
cache_params: Optional[MambaCache] = None,
|
||||
cache_params: Optional[FalconMambaCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten -- uses `cache_params` as opposed to `past_key_values`
|
||||
|
||||
if use_cache:
|
||||
# `cache_position` should have been initialized in `generate`
|
||||
if cache_position is None:
|
||||
raise ValueError(
|
||||
"`cache_position` should not be None as it should have been initialized in "
|
||||
"`model.generate`, you are responsible for passing in a valid `cache_position` if "
|
||||
"you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
|
||||
)
|
||||
if cache_position[0] > 0:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = None
|
||||
|
||||
model_inputs = {"input_ids": input_ids.contiguous()}
|
||||
if use_cache and cache_params is None:
|
||||
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
|
||||
# considering padding will be applied when input length is shorter, and truncation
|
||||
# will be applied when it is longer, so it will be equivalent to always have it match
|
||||
# the length of `cache_params.conv_states`, which is `config.conv_kernel`
|
||||
cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device)
|
||||
if inputs_embeds is not None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
max_batch_size = inputs_embeds.size(0)
|
||||
else:
|
||||
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
|
||||
# considering padding will be applied when input length is shorter, and truncation
|
||||
# will be applied when it is longer, so it will be equivalent to always have it match
|
||||
# the length of `cache_params.conv_states`, which is `config.conv_kernel`
|
||||
cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)
|
||||
max_batch_size = input_ids.size(0)
|
||||
cache_params = FalconMambaCache(self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype)
|
||||
|
||||
if inputs_embeds is not None and cache_params is None:
|
||||
if use_cache and cache_position[0] > 0:
|
||||
model_inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1).contiguous()
|
||||
attention_mask = None
|
||||
|
||||
if not use_cache and inputs_embeds is not None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids.contiguous()}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
@@ -753,7 +855,7 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
cache_params: Optional[MambaCache] = None,
|
||||
cache_params: Optional[FalconMambaCache] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
@@ -762,7 +864,7 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
|
||||
**kwargs, # for now we need this for generation
|
||||
) -> Union[tuple, FalconMambaCausalLMOutput]:
|
||||
r"""
|
||||
cache_params (`MambaCache`, *optional*):
|
||||
cache_params (`FalconMambaCache`, *optional*):
|
||||
If passed along, the model uses the previous state in all the blocks (which will give the output for the
|
||||
`input_ids` provided as if the model add `state_input_ids + input_ids` as context).
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@@ -811,4 +913,4 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel"]
|
||||
__all__ = ["FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel", "FalconMambaCache"]
|
||||
|
||||
540
src/transformers/models/falcon_mamba/modular_falcon_mamba.py
Normal file
540
src/transformers/models/falcon_mamba/modular_falcon_mamba.py
Normal file
@@ -0,0 +1,540 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch FALCONMAMBA model."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available
|
||||
from ..mamba.configuration_mamba import MambaConfig
|
||||
from ..mamba.modeling_mamba import (
|
||||
MambaBlock,
|
||||
MambaCache,
|
||||
MambaCausalLMOutput,
|
||||
MambaForCausalLM,
|
||||
MambaMixer,
|
||||
MambaModel,
|
||||
MambaOutput,
|
||||
MambaPreTrainedModel,
|
||||
MambaRMSNorm,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if is_mambapy_available():
|
||||
from mambapy.pscan import pscan
|
||||
else:
|
||||
pscan = None
|
||||
|
||||
if is_mamba_ssm_available():
|
||||
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
|
||||
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
||||
|
||||
from ...kernels.falcon_mamba import mamba_inner_fn
|
||||
else:
|
||||
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
|
||||
|
||||
if is_causal_conv1d_available():
|
||||
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
||||
else:
|
||||
causal_conv1d_update, causal_conv1d_fn = None, None
|
||||
|
||||
|
||||
class FalconMambaConfig(MambaConfig):
    """
    This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the FALCON_MAMBA
    [tiiuae/falcon-mamba-7b](https://huggingface.co/tiiuae/falcon-mamba-7b) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 50280):
            Vocabulary size of the FALCON_MAMBA model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`FalconMambaModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        state_size (`int`, *optional*, defaults to 16): shape of the state space latents.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the model.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
            The epsilon to use in the layer normalization layers.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 0):
            The id of the beginning of sentence token in the vocabulary.
        eos_token_id (`int`, *optional*, defaults to 0):
            The id of the end of sentence token in the vocabulary.
        expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
        conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
        use_bias (`bool`, *optional*, defaults to `False`):
            Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
        use_conv_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to use bias in the convolution layer of the mixer block.
        hidden_act (`str`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        initializer_range (`float`, *optional*, defaults to 0.1):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        residual_in_fp32 (`bool`, *optional*, defaults to `True`):
            Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model
        time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
            Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
        time_step_scale (`float`, *optional*, defaults to 1.0):
            Scale used to scale `dt_proj.bias`.
        time_step_min (`float`, *optional*, defaults to 0.001):
            Minimum `time_step` used to bound `dt_proj.bias`.
        time_step_max (`float`, *optional*, defaults to 0.1):
            Maximum `time_step` used to bound `dt_proj.bias`.
        time_step_init_scheme (`str`, *optional*, defaults to `"random"`):
            Init scheme used for `dt_proj.weight`. Should be one of `["random","uniform"]`
        time_step_floor (`float`, *optional*, defaults to 0.0001):
            Minimum clamping value of the `dt_proj.bias` layer initialization.
        rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
            Whether or not to rescale `out_proj` weights when initializing.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the cache should be used.
        use_falcon_mambapy (`bool`, *optional*, defaults to `False`):
            This argument corresponds to `use_mambapy` in MambaConfig.
            Determines the fallback strategy during training if the CUDA-based official implementation of Mamba is not available. If `True`, the mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
        mixer_rms_eps (`float`, *optional*, defaults to 1e-06):
            The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states.


    Example:

    ```python
    >>> from transformers import FalconMambaConfig, FalconMambaModel

    >>> # Initializing a FalconMamba configuration
    >>> configuration = FalconMambaConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = FalconMambaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    def __init__(
        self,
        vocab_size=50280,
        hidden_size=768,
        state_size=16,
        num_hidden_layers=32,
        layer_norm_epsilon=1e-5,
        pad_token_id=0,
        bos_token_id=0,
        eos_token_id=0,
        expand=2,
        conv_kernel=4,
        use_bias=False,
        use_conv_bias=True,
        hidden_act="silu",
        initializer_range=0.1,
        residual_in_fp32=True,
        time_step_rank="auto",
        time_step_scale=1.0,
        time_step_min=0.001,
        time_step_max=0.1,
        time_step_init_scheme="random",
        time_step_floor=1e-4,
        rescale_prenorm_residual=False,
        use_cache=True,
        use_falcon_mambapy=False,
        mixer_rms_eps=1e-6,
        **kwargs,
    ):
        # Every argument except `mixer_rms_eps` is forwarded by keyword to the parent constructor.
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            state_size=state_size,
            num_hidden_layers=num_hidden_layers,
            layer_norm_epsilon=layer_norm_epsilon,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            expand=expand,
            conv_kernel=conv_kernel,
            use_bias=use_bias,
            use_conv_bias=use_conv_bias,
            hidden_act=hidden_act,
            initializer_range=initializer_range,
            residual_in_fp32=residual_in_fp32,
            time_step_rank=time_step_rank,
            time_step_scale=time_step_scale,
            time_step_min=time_step_min,
            time_step_max=time_step_max,
            time_step_init_scheme=time_step_init_scheme,
            time_step_floor=time_step_floor,
            rescale_prenorm_residual=rescale_prenorm_residual,
            use_cache=use_cache,
            use_falcon_mambapy=use_falcon_mambapy,
            **kwargs,
        )
        # FalconMamba-specific setting, stored directly on this config (read by the mixer as `config.mixer_rms_eps`).
        self.mixer_rms_eps = mixer_rms_eps
|
||||
|
||||
|
||||
# FalconMamba-named cache type: subclasses `MambaCache` without adding or overriding anything.
class FalconMambaCache(MambaCache):
    pass
|
||||
|
||||
|
||||
def rms_forward(hidden_states, variance_epsilon=1e-6):
    """
    Apply RMS normalization over the last dimension without any learnable weight.

    The computation is carried out in `float32` and the result is cast back to the
    dtype of the input, so callers can multiply it by their own scale afterwards.

    Args:
        hidden_states (`torch.Tensor`):
            Tensor to normalize along its last dimension.
        variance_epsilon (`float`):
            Value added to the mean of squares before taking the reciprocal square root.

    Returns:
        `torch.Tensor`: the normalized tensor, in the same dtype as `hidden_states`.
    """
    original_dtype = hidden_states.dtype
    upcast = hidden_states.to(torch.float32)

    # Mean of squares over the last axis, kept as a broadcastable column.
    mean_square = upcast.pow(2).mean(-1, keepdim=True)
    normalized = upcast * torch.rsqrt(mean_square + variance_epsilon)
    return normalized.to(original_dtype)
|
||||
|
||||
|
||||
class FalconMambaMixer(MambaMixer):
    """
    Mamba mixer variant used by FalconMamba.

    Compared with the code it inherits, the visible overrides apply the weight-free `rms_forward`
    normalization to `B`, `C` and `time_step` before the state-space recurrence, and pass RMS weight
    buffers plus `mixer_rms_eps` to `mamba_inner_fn` on the fused training path.
    """

    def warn_slow_implementation(self):
        """
        Warn (once) when the fused kernels are unavailable and report which fallback is used.

        Raises:
            ImportError: if `self.use_falcon_mambapy` is set but the `mambapy` package is not available.
        """
        # The fast path needs every one of these optional kernel entry points to be importable.
        is_fast_path_available = all(
            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
        )
        if not is_fast_path_available:
            # NOTE(review): the messages below mention `use_mambapy` while the flag read here is
            # `use_falcon_mambapy` -- confirm whether the wording should be updated.
            if self.use_falcon_mambapy:
                if is_mambapy_available():
                    logger.warning_once(
                        "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
                        " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and"
                        " https://github.com/Dao-AILab/causal-conv1d"
                    )
                else:
                    raise ImportError(
                        "use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py."
                    )
            else:
                logger.warning_once(
                    "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
                    " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and"
                    " https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
                )

    def __init__(self, config: FalconMambaConfig, layer_idx: int):
        """Build the parent mixer, then register the constant RMS weight buffers and store `config.mixer_rms_eps`."""
        super().__init__(config, layer_idx)
        # Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here
        # Both buffers are all-ones, non-trainable and non-persistent (not saved in the state dict).
        self.register_buffer(
            "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False
        )
        self.register_buffer(
            "dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False
        )
        self.rms_eps = config.mixer_rms_eps

    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Optional[FalconMambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Forward pass built on the optional fused kernels.

        When `self.training` is set and no cache is given, the whole mixer is delegated to `mamba_inner_fn`.
        Otherwise the convolution and the selective scan are run through the `causal_conv1d_*` and
        `selective_*` kernels, reading and updating `cache_params` when it is provided.

        Args:
            hidden_states: input passed to `self.in_proj`.
            cache_params: optional cache holding `conv_states` / `ssm_states` per layer.
            cache_position: positions used to tell prefill from decoding (`cache_position[0] > 0` means decoding)
                and forwarded to `update_conv_state`.
            attention_mask: optional mask multiplied into the hidden states before and after the convolution.

        Returns:
            The output of `mamba_inner_fn` on the fused path, otherwise the output of `self.out_proj`.
        """
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states).transpose(1, 2)

        if self.training and cache_params is None:  # Doesn't support outputting the states -> used for training
            contextualized_states = mamba_inner_fn(
                projected_states,
                self.conv1d.weight,
                self.conv1d.bias if self.use_conv_bias else None,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias.float() if self.use_bias else None,
                -torch.exp(self.A_log.float()),
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                b_rms_weight=self.b_c_rms,
                c_rms_weight=self.b_c_rms,
                dt_rms_weight=self.dt_rms,
                b_c_dt_rms_eps=self.rms_eps,
            )

        else:
            hidden_states, gate = projected_states.chunk(2, dim=1)

            if attention_mask is not None:
                hidden_states = hidden_states * attention_mask.unsqueeze(1)

            # 2. Convolution sequence transformation
            conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
            if cache_params is not None and cache_position[0] > 0:
                # Decoding step: the single-step kernel is given this layer's cached conv state.
                hidden_states = causal_conv1d_update(
                    hidden_states.squeeze(-1),
                    cache_params.conv_states[self.layer_idx],
                    conv_weights,
                    self.conv1d.bias,
                    self.activation,
                )
                hidden_states = hidden_states.unsqueeze(-1)
            else:
                if cache_params is not None:
                    # Prefill: pad on the left so the last dimension becomes `conv_kernel_size`
                    # (a negative pad width would truncate instead), then store it in the cache.
                    conv_states = nn.functional.pad(
                        hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
                    )
                    cache_params.update_conv_state(self.layer_idx, conv_states, cache_position)
                hidden_states = causal_conv1d_fn(
                    hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
                )

            if attention_mask is not None:
                hidden_states = hidden_states * attention_mask.unsqueeze(1)

            # 3. State Space Model sequence transformation
            # 3.a. input varying initialization of time_step, B and C
            ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
            time_step, B, C = torch.split(
                ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
            )

            # FalconMamba-specific step: weight-free RMS normalization of B, C and the time step.
            B = rms_forward(B, variance_epsilon=self.rms_eps)
            C = rms_forward(C, variance_epsilon=self.rms_eps)
            time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)

            # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
            # at the price of a small overhead.
            if hasattr(self.config, "_pre_quantization_dtype"):
                discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2)
            else:
                discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)

            A = -torch.exp(self.A_log.float())
            # 3.c perform the recurrence y ← SSM(A, B, C)(x)
            # The bias is excluded from `discrete_time_step` above and handed to the kernels separately.
            time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
            if cache_params is not None and cache_position[0] > 0:
                scan_outputs = selective_state_update(
                    cache_params.ssm_states[self.layer_idx],
                    hidden_states[..., 0],
                    discrete_time_step[..., 0],
                    A,
                    B[:, 0],
                    C[:, 0],
                    self.D,
                    gate[..., 0],
                    time_proj_bias,
                    dt_softplus=True,
                ).unsqueeze(-1)
            else:
                scan_outputs, ssm_state = selective_scan_fn(
                    hidden_states,
                    discrete_time_step,
                    A,
                    B.transpose(1, 2),
                    C.transpose(1, 2),
                    self.D.float(),
                    gate,
                    time_proj_bias,
                    delta_softplus=True,
                    return_last_state=True,
                )
                if ssm_state is not None and cache_params is not None:
                    cache_params.update_ssm_state(self.layer_idx, ssm_state)

            # 4. Final linear projection
            contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
        return contextualized_states

    def slow_forward(
        self,
        input_states,
        cache_params: Optional[FalconMambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Pure PyTorch forward pass, used when the fused kernels are not taken.

        Args:
            input_states: 3-D input unpacked as `(batch_size, seq_len, _)`.
            cache_params: optional cache; when given, its SSM state seeds the recurrence and both its
                conv and SSM states are updated.
            cache_position: prefill is detected by `cache_position.shape[0] == self.conv_kernel_size`.
            attention_mask: optional mask multiplied into the hidden states before and after the convolution.

        Returns:
            The output of `self.out_proj` applied to the gated scan output.
        """
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(input_states).transpose(1, 2)  # [batch, 2 * intermediate_size, seq_len]
        hidden_states, gate = projected_states.chunk(2, dim=1)

        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(1)

        # 2. Convolution sequence transformation
        if cache_params is not None:
            # Work on a copy of the cached SSM state; it is written back through `update_ssm_state` below.
            ssm_state = cache_params.ssm_states[self.layer_idx].clone()
            ssm_state = ssm_state.to(hidden_states.device)
            # use `cache_position.shape[0]` to check whether we are in prefill
            # stage, it's equivalent to check `cache_position[0] == 0`, which
            # breaks dynamo fullgraph constraints
            if cache_position is not None and cache_position.shape[0] == self.conv_kernel_size:
                conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))

                cache_params.update_conv_state(self.layer_idx, conv_state, cache_position)
                hidden_states = self.act(
                    self.conv1d(hidden_states)[..., :seq_len]
                )  # [batch, intermediate_size, seq_len]
            else:
                # Decoding: the convolution is computed by hand from the updated cached window.
                conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
                conv_state = conv_state.to(self.conv1d.weight.device)
                hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
                if self.use_conv_bias:
                    hidden_states += self.conv1d.bias
                hidden_states = (
                    self.act(hidden_states).to(dtype).unsqueeze(-1)
                )  # [batch, intermediate_size, 1] : decoding
        else:
            # No cache: the recurrence starts from an all-zero state.
            ssm_state = torch.zeros(
                (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype
            )
            hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])  # [batch, intermediate_size, seq_len]

        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(1)

        # 3. State Space Model sequence transformation
        # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
        time_step, B, C = torch.split(
            ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
        )

        # FalconMamba-specific step: weight-free RMS normalization of B, C and the time step.
        B = rms_forward(B, variance_epsilon=self.rms_eps)
        C = rms_forward(C, variance_epsilon=self.rms_eps)
        time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)

        discrete_time_step = self.dt_proj(time_step)  # [batch, seq_len, intermediate_size]
        discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(
            1, 2
        )  # [batch, intermediate_size, seq_len]

        # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
        A = -torch.exp(self.A_log.float())  # [intermediate_size, ssm_state_size]
        discrete_A = torch.exp(
            A[None, :, None, :] * discrete_time_step[:, :, :, None]
        )  # [batch, intermediate_size, seq_len, ssm_state_size]
        discrete_B = (
            discrete_time_step[:, :, :, None] * B[:, None, :, :].float()
        )  # [batch, intermediate_size, seq_len, ssm_state_size]
        deltaB_u = discrete_B * hidden_states[:, :, :, None].float()

        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        if self.use_falcon_mambapy and self.training and cache_params is None:
            # Parallel scan from mambapy; only taken in training without a cache.
            hs = pscan(
                discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)
            )  # [batch, seq_len, intermediate_size, ssm_state_size]
            scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2)  # [batch, intermediate_size, seq_len]
            scan_output = scan_output + hidden_states * self.D[None, :, None]
            scan_output = scan_output * self.act(gate)
        else:
            # Sequential recurrence, one time step per iteration.
            scan_outputs = []
            for i in range(seq_len):
                ssm_state = (
                    discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
                )  # [batch, intermediate_size, ssm_state]
                scan_output = torch.matmul(
                    ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)
                )  # [batch, intermediate_size, 1]
                scan_outputs.append(scan_output[:, :, 0])
            scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, intermediate_size, seq_len]
            scan_output = scan_output + (hidden_states * self.D[None, :, None])
            scan_output = scan_output * self.act(gate)

            if cache_params is not None:
                cache_params.update_ssm_state(self.layer_idx, ssm_state)

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
        return contextualized_states

    def forward(
        self,
        hidden_states,
        cache_params: Optional[FalconMambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Dispatch to `cuda_kernels_forward` when all fused kernels are importable, the `x_proj` weight
        lives on a CUDA device and dynamo is not compiling; otherwise fall back to `slow_forward`.
        """
        is_fast_path_available = all(
            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
        )
        if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
        return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
|
||||
|
||||
|
||||
class FalconMambaRMSNorm(MambaRMSNorm):
    def forward(self, hidden_states):
        """Return `rms_forward(hidden_states)` scaled by `self.weight`, with the weight moved to the input's device."""
        scale = self.weight.to(hidden_states.device)
        normalized = rms_forward(hidden_states, variance_epsilon=self.variance_epsilon)
        return scale * normalized
|
||||
|
||||
|
||||
# FalconMamba-named block: subclasses `MambaBlock` without adding or overriding anything.
class FalconMambaBlock(MambaBlock):
    pass
|
||||
|
||||
|
||||
# FalconMamba-named base model class: subclasses `MambaPreTrainedModel` with an empty body;
# only the `auto_docstring` decorator is applied on top.
@auto_docstring
class FalconMambaPreTrainedModel(MambaPreTrainedModel):
    pass
|
||||
|
||||
|
||||
# FalconMamba-named output container: subclasses `MambaOutput` without adding or overriding anything.
class FalconMambaOutput(MambaOutput):
    pass
|
||||
|
||||
|
||||
# FalconMamba-named causal-LM output container: subclasses `MambaCausalLMOutput` without adding or overriding anything.
class FalconMambaCausalLMOutput(MambaCausalLMOutput):
    pass
|
||||
|
||||
|
||||
class FalconMambaModel(MambaModel, FalconMambaPreTrainedModel):
    """
    FalconMamba backbone: token embeddings, a stack of `FalconMambaBlock` layers and a final
    `FalconMambaRMSNorm`. Everything not defined here comes from `MambaModel`.
    """

    def __init__(self, config):
        """
        Build the backbone from `config`.

        Args:
            config: configuration providing `vocab_size`, `hidden_size`, `num_hidden_layers`
                and `layer_norm_epsilon`.
        """
        # Explicit parent call on the class: `self` has to be passed as well. Without it `config`
        # is bound as `self` and the call fails because the `config` argument is then missing.
        FalconMambaPreTrainedModel.__init__(self, config)

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [FalconMambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]
        )

        self.gradient_checkpointing = False
        self.norm_f = FalconMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # Initialize weights and apply final processing
        self.post_init()

    def load_hook(self, state_dict, prefix, *args):
        # Overrides the inherited hook so that it always raises.
        # NOTE(review): presumably this is the modular-converter marker for dropping the method
        # from the generated modeling file -- confirm against the converter's conventions.
        raise AttributeError("Not needed for FalconMamba")
|
||||
|
||||
|
||||
class FalconMambaForCausalLM(MambaForCausalLM):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FalconMambaForCausalLM",
|
||||
"FalconMambaModel",
|
||||
"FalconMambaPreTrainedModel",
|
||||
"FalconMambaCache",
|
||||
"FalconMambaConfig",
|
||||
]
|
||||
@@ -24,7 +24,7 @@ from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import MambaCache
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
@@ -55,9 +55,106 @@ if is_causal_conv1d_available():
|
||||
else:
|
||||
causal_conv1d_update, causal_conv1d_fn = None, None
|
||||
|
||||
is_fast_path_available = all(
|
||||
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
|
||||
)
|
||||
|
||||
class MambaCache:
|
||||
"""
|
||||
Cache for mamba model which does not have attention mechanism and key value states.
|
||||
|
||||
Arguments:
|
||||
config (`PretrainedConfig):
|
||||
The configuration file defining the shape-related attributes required to initialize the static cache.
|
||||
max_batch_size (`int`):
|
||||
The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used.
|
||||
dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
|
||||
The default `dtype` to use when initializing the layer.
|
||||
device (`torch.device` or `str`, *optional*):
|
||||
The device on which the cache should be initialized. Should be the same as the layer.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache
|
||||
|
||||
>>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
|
||||
|
||||
>>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
|
||||
|
||||
>>> # Prepare a cache class and pass it to model's forward
|
||||
>>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
|
||||
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
||||
>>> outputs.past_key_values
|
||||
MambaCache()
|
||||
```
|
||||
"""
|
||||
|
||||
is_compileable = True
|
||||
|
||||
# TODO (joao): add layer_device_map arg and update code in `generate` accordingly
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
max_batch_size: int,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
device: Union[torch.device, str, None] = None,
|
||||
):
|
||||
self.max_batch_size = max_batch_size
|
||||
self._dtype = dtype
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.ssm_state_size = config.state_size
|
||||
self.conv_kernel_size = config.conv_kernel
|
||||
|
||||
self.conv_states: list[torch.Tensor] = []
|
||||
self.ssm_states: list[torch.Tensor] = []
|
||||
device = torch.device(device) if device is not None else None
|
||||
for _ in range(config.num_hidden_layers):
|
||||
conv_state: torch.Tensor = torch.zeros(
|
||||
self.max_batch_size,
|
||||
self.intermediate_size,
|
||||
self.conv_kernel_size,
|
||||
device=device,
|
||||
dtype=self._dtype,
|
||||
)
|
||||
ssm_state: torch.Tensor = torch.zeros(
|
||||
self.max_batch_size,
|
||||
self.intermediate_size,
|
||||
self.ssm_state_size,
|
||||
device=device,
|
||||
dtype=self._dtype,
|
||||
)
|
||||
|
||||
torch._dynamo.mark_static_address(conv_state)
|
||||
torch._dynamo.mark_static_address(ssm_state)
|
||||
self.conv_states.append(conv_state)
|
||||
self.ssm_states.append(ssm_state)
|
||||
|
||||
def update_conv_state(
|
||||
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
|
||||
) -> torch.Tensor:
|
||||
# This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used
|
||||
# when the cache is initialized in the forward pass (e.g. Mamba)
|
||||
if self.conv_states[layer_idx].device != new_conv_state.device:
|
||||
self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device)
|
||||
|
||||
conv_state = self.conv_states[layer_idx]
|
||||
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
|
||||
|
||||
conv_state = conv_state.roll(shifts=-1, dims=-1)
|
||||
conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
|
||||
self.conv_states[layer_idx].zero_()
|
||||
self.conv_states[layer_idx] += conv_state
|
||||
return self.conv_states[layer_idx]
|
||||
|
||||
def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
|
||||
self.ssm_states[layer_idx].zero_()
|
||||
self.ssm_states[layer_idx] += new_ssm_state.to(self.ssm_states[layer_idx].device)
|
||||
return self.ssm_states[layer_idx]
|
||||
|
||||
def reset(self):
|
||||
for layer_idx in range(len(self.conv_states)):
|
||||
# In-place ops prevent breaking the static address
|
||||
self.conv_states[layer_idx].zero_()
|
||||
self.ssm_states[layer_idx].zero_()
|
||||
|
||||
|
||||
class MambaMixer(nn.Module):
|
||||
@@ -109,6 +206,12 @@ class MambaMixer(nn.Module):
|
||||
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
|
||||
self.use_bias = config.use_bias
|
||||
|
||||
self.warn_slow_implementation()
|
||||
|
||||
def warn_slow_implementation(self):
|
||||
is_fast_path_available = all(
|
||||
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
|
||||
)
|
||||
if not is_fast_path_available:
|
||||
if self.use_mambapy:
|
||||
if is_mambapy_available():
|
||||
@@ -319,6 +422,9 @@ class MambaMixer(nn.Module):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
is_fast_path_available = all(
|
||||
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
|
||||
)
|
||||
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
|
||||
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
|
||||
return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
|
||||
@@ -650,32 +756,26 @@ class MambaForCausalLM(MambaPreTrainedModel, GenerationMixin):
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten -- uses `cache_params` as opposed to `past_key_values`
|
||||
|
||||
if use_cache:
|
||||
# `cache_position` should have been initialized in `generate`
|
||||
if cache_position is None:
|
||||
raise ValueError(
|
||||
"`cache_position` should not be None as it should have been initialized in "
|
||||
"`model.generate`, you are responsible for passing in a valid `cache_position` if "
|
||||
"you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
|
||||
)
|
||||
if cache_position[0] > 0:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = None
|
||||
|
||||
model_inputs = {"input_ids": input_ids.contiguous()}
|
||||
if use_cache and cache_params is None:
|
||||
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
|
||||
# considering padding will be applied when input length is shorter, and truncation
|
||||
# will be applied when it is longer, so it will be equivalent to always have it match
|
||||
# the length of `cache_params.conv_states`, which is `config.conv_kernel`
|
||||
cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device)
|
||||
if inputs_embeds is not None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
max_batch_size = inputs_embeds.size(0)
|
||||
else:
|
||||
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
|
||||
# considering padding will be applied when input length is shorter, and truncation
|
||||
# will be applied when it is longer, so it will be equivalent to always have it match
|
||||
# the length of `cache_params.conv_states`, which is `config.conv_kernel`
|
||||
cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)
|
||||
max_batch_size = input_ids.size(0)
|
||||
cache_params = MambaCache(self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype)
|
||||
|
||||
if inputs_embeds is not None and cache_params is None:
|
||||
if use_cache and cache_position[0] > 0:
|
||||
model_inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1).contiguous()
|
||||
attention_mask = None
|
||||
|
||||
if not use_cache and inputs_embeds is not None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids.contiguous()}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
@@ -751,4 +851,4 @@ class MambaForCausalLM(MambaPreTrainedModel, GenerationMixin):
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["MambaForCausalLM", "MambaModel", "MambaPreTrainedModel"]
|
||||
__all__ = ["MambaForCausalLM", "MambaModel", "MambaPreTrainedModel", "MambaCache"]
|
||||
|
||||
@@ -970,38 +970,33 @@ class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin):
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten -- uses `cache_params` as opposed to `past_key_values`
|
||||
|
||||
if use_cache:
|
||||
# `cache_position` should have been initialized in `generate`
|
||||
if cache_position is None:
|
||||
raise ValueError(
|
||||
"`cache_position` should not be None as it should have been initialized in "
|
||||
"`model.generate`, you are responsible for passing in a valid `cache_position` if "
|
||||
"you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
|
||||
)
|
||||
if cache_position[0] > 0:
|
||||
input_ids = input_ids[:, -1][..., None]
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = None
|
||||
model_inputs = {"input_ids": input_ids.contiguous()}
|
||||
if use_cache and cache_params is None:
|
||||
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
|
||||
# considering padding will be applied when input length is shorter, and truncation
|
||||
# will be applied when it is longer, so it will be equivalent to always have it match
|
||||
# the length of `cache_params.conv_states`, which is `config.conv_kernel`
|
||||
cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device)
|
||||
if inputs_embeds is not None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
max_batch_size = inputs_embeds.size(0)
|
||||
else:
|
||||
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
|
||||
# considering padding will be applied when input length is shorter, and truncation
|
||||
# will be applied when it is longer, so it will be equivalent to always have it match
|
||||
# the length of `cache_params.conv_states`, which is `config.conv_kernel`
|
||||
cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)
|
||||
max_batch_size = input_ids.size(0)
|
||||
cache_params = Mamba2Cache(self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype)
|
||||
|
||||
if inputs_embeds is not None and cache_params is None:
|
||||
if use_cache and cache_position[0] > 0:
|
||||
model_inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1).contiguous()
|
||||
attention_mask = None
|
||||
|
||||
if not use_cache and inputs_embeds is not None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"attention_mask": attention_mask,
|
||||
"cache_params": cache_params,
|
||||
"use_cache": use_cache,
|
||||
"cache_position": cache_position,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@@ -44,13 +44,6 @@ class HybridCache(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MambaCache(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class OffloadedCache(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -26,7 +26,6 @@ from transformers.testing_utils import (
|
||||
require_torch_accelerator,
|
||||
require_torch_large_accelerator,
|
||||
require_torch_multi_accelerator,
|
||||
require_torch_multi_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -41,10 +40,10 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
FalconMambaCache,
|
||||
FalconMambaForCausalLM,
|
||||
FalconMambaModel,
|
||||
)
|
||||
from transformers.cache_utils import MambaCache
|
||||
|
||||
|
||||
# Copied from transformers.tests.models.mamba.MambaModelTester with Mamba->FalconMamba,mamba->falcon_mamba
|
||||
@@ -312,31 +311,6 @@ class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# some params shouldn't be scattered by nn.DataParallel
|
||||
# so just remove them if they are present.
|
||||
blacklist_non_batched_params = ["cache_params"]
|
||||
for k in blacklist_non_batched_params:
|
||||
inputs_dict.pop(k, None)
|
||||
|
||||
# move input tensors to cuda:O
|
||||
for k, v in inputs_dict.items():
|
||||
if torch.is_tensor(v):
|
||||
inputs_dict[k] = v.to(0)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=config)
|
||||
model.to(0)
|
||||
model.eval()
|
||||
|
||||
# Wrap model in nn.DataParallel
|
||||
model = torch.nn.DataParallel(model)
|
||||
with torch.no_grad():
|
||||
_ = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
def test_falcon_mamba_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_falcon_mamba_model(*config_and_inputs)
|
||||
@@ -411,7 +385,7 @@ class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, MambaCache): # MODIFIED PART START
|
||||
if isinstance(tuple_object, FalconMambaCache): # MODIFIED PART START
|
||||
recursive_check(tuple_object.conv_states, dict_object.conv_states)
|
||||
recursive_check(tuple_object.ssm_states, dict_object.ssm_states)
|
||||
elif isinstance(tuple_object, (list, tuple)): # MODIFIED PART END
|
||||
@@ -458,6 +432,10 @@ class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||
|
||||
@unittest.skip("Mamba models do not support DDP.")
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@@ -497,7 +475,9 @@ class FalconMambaIntegrationTests(unittest.TestCase):
|
||||
@require_bitsandbytes
|
||||
def test_generation_4bit(self):
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_id, quantization_config=quantization_config)
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_id, quantization_config=quantization_config).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
inputs = self.tokenizer(self.text, return_tensors="pt").to(torch_device)
|
||||
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
@@ -513,6 +493,7 @@ class FalconMambaIntegrationTests(unittest.TestCase):
|
||||
|
||||
inputs = self.tokenizer(self.text, return_tensors="pt").to(torch_device)
|
||||
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
print(self.tokenizer.batch_decode(out, skip_special_tokens=False)[0])
|
||||
|
||||
self.assertEqual(
|
||||
self.tokenizer.batch_decode(out, skip_special_tokens=False)[0],
|
||||
@@ -543,7 +524,7 @@ class FalconMambaIntegrationTests(unittest.TestCase):
|
||||
inputs = tok(texts, return_tensors="pt", padding=True, return_token_type_ids=False).to(torch_device)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=0, torch_dtype=torch.float16)
|
||||
|
||||
out = model.generate(**inputs, max_new_tokens=20)
|
||||
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
out = tok.batch_decode(out, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(out, EXPECTED_OUTPUT)
|
||||
@@ -553,7 +534,7 @@ class FalconMambaIntegrationTests(unittest.TestCase):
|
||||
inputs_embeds = model.get_input_embeddings()(inputs.pop("input_ids"))
|
||||
|
||||
inputs["inputs_embeds"] = inputs_embeds
|
||||
out = model.generate(**inputs, max_new_tokens=20)
|
||||
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
out = tok.batch_decode(out, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_OUTPUTS = Expectations(
|
||||
|
||||
@@ -20,7 +20,7 @@ from unittest.util import safe_repr
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, MambaConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -32,10 +32,10 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
MambaCache,
|
||||
MambaForCausalLM,
|
||||
MambaModel,
|
||||
)
|
||||
from transformers.models.mamba.modeling_mamba import MambaCache
|
||||
|
||||
|
||||
class MambaModelTester:
|
||||
@@ -279,31 +279,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# some params shouldn't be scattered by nn.DataParallel
|
||||
# so just remove them if they are present.
|
||||
blacklist_non_batched_params = ["cache_params"]
|
||||
for k in blacklist_non_batched_params:
|
||||
inputs_dict.pop(k, None)
|
||||
|
||||
# move input tensors to cuda:O
|
||||
for k, v in inputs_dict.items():
|
||||
if torch.is_tensor(v):
|
||||
inputs_dict[k] = v.to(0)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=config)
|
||||
model.to(0)
|
||||
model.eval()
|
||||
|
||||
# Wrap model in nn.DataParallel
|
||||
model = torch.nn.DataParallel(model)
|
||||
with torch.no_grad():
|
||||
_ = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
def test_mamba_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_mamba_model(*config_and_inputs)
|
||||
@@ -452,6 +427,10 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.hidden_size),
|
||||
)
|
||||
|
||||
@unittest.skip("Mamba models do not support DDP.")
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class MambaIntegrationTests(unittest.TestCase):
|
||||
@@ -547,11 +526,11 @@ class MambaIntegrationTests(unittest.TestCase):
|
||||
torch_device
|
||||
)
|
||||
|
||||
output = model.generate(input_ids, max_new_tokens=20, cache_implementation="mamba")
|
||||
output = model.generate(input_ids, max_new_tokens=20)
|
||||
output_sentence = self.tokenizer.decode(output[0].tolist())
|
||||
self.assertEqual(output_sentence, expected_output)
|
||||
|
||||
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
|
||||
output = model.generate(input_ids, max_new_tokens=20, cache_implementation="mamba")
|
||||
output = model.generate(input_ids, max_new_tokens=20)
|
||||
output_sentence = self.tokenizer.decode(output[0].tolist())
|
||||
self.assertEqual(output_sentence, expected_output)
|
||||
|
||||
@@ -62,8 +62,6 @@ if is_torch_available():
|
||||
TEST_CACHE_IMPLEMENTATIONS = [
|
||||
cache_name
|
||||
for cache_name in ALL_CACHE_IMPLEMENTATIONS
|
||||
# TODO (joao): Mamba is not compatible with most models, remove from `ALL_CACHE_IMPLEMENTATIONS`?
|
||||
if cache_name != "mamba"
|
||||
# TODO (joao): offloaded_hybrid == offloaded_hybrid_chunked, deprecate one of them
|
||||
if cache_name != "offloaded_hybrid"
|
||||
]
|
||||
|
||||
@@ -141,7 +141,12 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
return updated_node
|
||||
|
||||
def leave_ImportFrom(self, original_node, updated_node):
|
||||
"""The imports from other file types (configuration, processing etc) should use original model name."""
|
||||
"""
|
||||
The imports from other file types (configuration, processing etc) should use original model name.
|
||||
Also, no replaces on absolute imports (e.g. `from mamba_ssm import ...`)
|
||||
"""
|
||||
if len(original_node.relative) == 0: # no replaces on absolute imports
|
||||
return original_node
|
||||
if self.original_new_model_name != self.new_name and m.matches(updated_node.module, m.Name()):
|
||||
patterns = "|".join(ALL_FILE_TYPES)
|
||||
regex = rf"({patterns})_{self.new_name}"
|
||||
|
||||
Reference in New Issue
Block a user