Add padding-free to bamba (#35861)

* add seq_idx and fa kwargs

* update tests

* docs and grad ckpt support

* fmt

* better names

* test_raise_missing_padding_free_kwarg_errs

* + seq_idx in doc strings

* padding free training docs

* add link to pr plots

* raise err on attn_mask with padding free

* rm raising missing padding free err test

* BambaFlashAttentionKwargs

* run modular util for modular_granitemoehybrid.py
This commit is contained in:
Garrett Goon
2025-05-20 11:13:59 -04:00
committed by GitHub
parent 2a79471318
commit 390f153469
5 changed files with 233 additions and 25 deletions

View File

@@ -39,7 +39,7 @@ Checkout all Bamba-9B model checkpoints [here](https://github.com/foundation-mod
<!--- <!---
## Usage Tips ## Usage Tips
Tips: Tips:
- The architecture is based on Mamba-2 models. - The architecture is based on Mamba-2 models.
@@ -63,7 +63,35 @@ response = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0]) print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
``` ```
## Padding-Free Training
Bamba supports padding-free training in which distinct training examples can be concatenated
together while nevertheless processing the inputs as though they belonged to separate batches. When
the examples are of varying lengths, padding-free training can provide significant speed ups and
memory savings compared to batching the examples together and using padding, as the unnecessary
compute and memory due to padding is avoided entirely. The performance gains depend on factors such
as the model and the data distribution, but throughput gains up to [~2x are commonly
seen](https://github.com/huggingface/transformers/pull/35861#issue-2807873129).
Using padding-free training with Bamba requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d`
packages, and the following arguments must be passed to the model in addition to `input_ids` and
`labels`:
* `position_ids: torch.LongTensor`: the position index of each token in each sequence.
* `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
* Each of the [`FlashAttentionKwargs`]
* `cu_seq_lens_q: torch.LongTensor`: The cumulative sequence lengths of all queries.
* `cu_seq_lens_k: torch.LongTensor`: The cumulative sequence lengths of all keys.
* `max_length_q: int`: the longest query length in the batch.
* `max_length_k: int`: the longest key length in the batch.
The `attention_mask` inputs should not be provided. The [`DataCollatorWithFlattening`] can be used
to programmatically generate the above set of additional arguments using `return_seq_idx=True` and
`return_flash_attn_kwargs=True`. See [this blog post](https://huggingface.co/blog/packing-with-FA2)
for additional information.
[[autodoc]] BambaForCausalLM [[autodoc]] BambaForCausalLM
- forward - forward
This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim). This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).

View File

@@ -24,7 +24,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Callable, Optional, Tuple, Union from functools import partial
from typing import Callable, Optional, Tuple, TypedDict, Union
import torch import torch
from torch import nn from torch import nn
@@ -61,6 +62,31 @@ else:
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
class BambaFlashAttentionKwargs(TypedDict, total=False):
"""
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
Use cases include padding-free training and fewer `torch.compile` graph breaks.
Attributes:
cu_seq_lens_q (`torch.LongTensor`)
Gets cumulative sequence length for query state.
cu_seq_lens_k (`torch.LongTensor`)
Gets cumulative sequence length for key state.
max_length_q (`int`):
Maximum sequence length for query state.
max_length_k (`int`):
Maximum sequence length for key state.
seq_idx (`torch.IntTensor):
Index of each packed sequence.
"""
cu_seq_lens_q: torch.LongTensor
cu_seq_lens_k: torch.LongTensor
max_length_q: int
max_length_k: int
seq_idx: torch.IntTensor
# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer # Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache): class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
""" """
@@ -487,6 +513,7 @@ class BambaMixer(nn.Module):
cache_params: Optional[HybridMambaAttentionDynamicCache] = None, cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.IntTensor] = None,
): ):
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
@@ -569,7 +596,7 @@ class BambaMixer(nn.Module):
A, A,
D=self.D, D=self.D,
chunk_size=self.chunk_size, chunk_size=self.chunk_size,
seq_idx=None, # was seq_idx seq_idx=seq_idx,
activation=self.activation, activation=self.activation,
rmsnorm_weight=self.norm.weight, rmsnorm_weight=self.norm.weight,
rmsnorm_eps=self.norm.variance_epsilon, rmsnorm_eps=self.norm.variance_epsilon,
@@ -610,6 +637,7 @@ class BambaMixer(nn.Module):
weight=self.conv1d.weight.squeeze(1), weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias, bias=self.conv1d.bias,
activation=self.activation, activation=self.activation,
seq_idx=seq_idx,
).transpose(1, 2) ).transpose(1, 2)
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
@@ -629,7 +657,7 @@ class BambaMixer(nn.Module):
chunk_size=self.chunk_size, chunk_size=self.chunk_size,
D=self.D, D=self.D,
z=None, z=None,
seq_idx=None, seq_idx=seq_idx,
return_final_states=True, return_final_states=True,
dt_bias=self.dt_bias, dt_bias=self.dt_bias,
dt_softplus=True, dt_softplus=True,
@@ -863,9 +891,15 @@ class BambaMixer(nn.Module):
cache_params: Optional[HybridMambaAttentionDynamicCache] = None, cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.IntTensor] = None,
**kwargs,
): ):
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
if seq_idx is not None:
raise NotImplementedError(
"`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
)
dtype = hidden_states.dtype dtype = hidden_states.dtype
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
@@ -939,7 +973,7 @@ class BambaDecoderLayer(nn.Module):
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs, **kwargs: Unpack[BambaFlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -959,8 +993,8 @@ class BambaDecoderLayer(nn.Module):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head. with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*): kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
into the model padding-free training and/or improve torch.compile performance.
""" """
residual = hidden_states residual = hidden_states
@@ -974,6 +1008,7 @@ class BambaDecoderLayer(nn.Module):
cache_params=past_key_value, cache_params=past_key_value,
cache_position=cache_position, cache_position=cache_position,
attention_mask=attention_mask, attention_mask=attention_mask,
**kwargs,
) )
self_attn_weights = None self_attn_weights = None
elif self.layer_type == "attention": elif self.layer_type == "attention":
@@ -1076,7 +1111,7 @@ class BambaModel(BambaPreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs, # NOOP kwargs, for now **kwargs: Unpack[BambaFlashAttentionKwargs],
) -> BaseModelOutputWithPast: ) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
@@ -1128,7 +1163,7 @@ class BambaModel(BambaPreTrainedModel):
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__, partial(decoder_layer.__call__, **kwargs),
hidden_states, hidden_states,
layer_mask, layer_mask,
position_ids, position_ids,
@@ -1148,6 +1183,7 @@ class BambaModel(BambaPreTrainedModel):
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings, position_embeddings=position_embeddings,
**kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]

View File

@@ -19,7 +19,8 @@
# limitations under the License. # limitations under the License.
"""PyTorch Bamba model.""" """PyTorch Bamba model."""
from typing import Optional, Tuple, Union from functools import partial
from typing import Optional, Tuple, TypedDict, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
@@ -46,7 +47,12 @@ from transformers.models.mamba2.modeling_mamba2 import (
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, can_return_tuple, logging from ...processing_utils import Unpack
from ...utils import (
auto_docstring,
can_return_tuple,
logging,
)
from ...utils.import_utils import is_causal_conv1d_available, is_flash_attn_2_available, is_mamba_2_ssm_available from ...utils.import_utils import is_causal_conv1d_available, is_flash_attn_2_available, is_mamba_2_ssm_available
from .configuration_bamba import BambaConfig from .configuration_bamba import BambaConfig
@@ -71,6 +77,31 @@ is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_c
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
class BambaFlashAttentionKwargs(TypedDict, total=False):
"""
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
Use cases include padding-free training and fewer `torch.compile` graph breaks.
Attributes:
cu_seq_lens_q (`torch.LongTensor`)
Gets cumulative sequence length for query state.
cu_seq_lens_k (`torch.LongTensor`)
Gets cumulative sequence length for key state.
max_length_q (`int`):
Maximum sequence length for query state.
max_length_k (`int`):
Maximum sequence length for key state.
seq_idx (`torch.IntTensor):
Index of each packed sequence.
"""
cu_seq_lens_q: torch.LongTensor
cu_seq_lens_k: torch.LongTensor
max_length_q: int
max_length_k: int
seq_idx: torch.IntTensor
# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer # Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache): class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
""" """
@@ -278,6 +309,7 @@ class BambaMixer(nn.Module):
cache_params: Optional[HybridMambaAttentionDynamicCache] = None, cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.IntTensor] = None,
): ):
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
@@ -360,7 +392,7 @@ class BambaMixer(nn.Module):
A, A,
D=self.D, D=self.D,
chunk_size=self.chunk_size, chunk_size=self.chunk_size,
seq_idx=None, # was seq_idx seq_idx=seq_idx,
activation=self.activation, activation=self.activation,
rmsnorm_weight=self.norm.weight, rmsnorm_weight=self.norm.weight,
rmsnorm_eps=self.norm.variance_epsilon, rmsnorm_eps=self.norm.variance_epsilon,
@@ -401,6 +433,7 @@ class BambaMixer(nn.Module):
weight=self.conv1d.weight.squeeze(1), weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias, bias=self.conv1d.bias,
activation=self.activation, activation=self.activation,
seq_idx=seq_idx,
).transpose(1, 2) ).transpose(1, 2)
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
@@ -420,7 +453,7 @@ class BambaMixer(nn.Module):
chunk_size=self.chunk_size, chunk_size=self.chunk_size,
D=self.D, D=self.D,
z=None, z=None,
seq_idx=None, seq_idx=seq_idx,
return_final_states=True, return_final_states=True,
dt_bias=self.dt_bias, dt_bias=self.dt_bias,
dt_softplus=True, dt_softplus=True,
@@ -654,9 +687,15 @@ class BambaMixer(nn.Module):
cache_params: Optional[HybridMambaAttentionDynamicCache] = None, cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.IntTensor] = None,
**kwargs,
): ):
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
if seq_idx is not None:
raise NotImplementedError(
"`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
)
dtype = hidden_states.dtype dtype = hidden_states.dtype
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
@@ -701,7 +740,7 @@ class BambaDecoderLayer(JambaAttentionDecoderLayer):
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs, **kwargs: Unpack[BambaFlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -721,8 +760,8 @@ class BambaDecoderLayer(JambaAttentionDecoderLayer):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head. with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*): kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
into the model padding-free training and/or improve torch.compile performance.
""" """
residual = hidden_states residual = hidden_states
@@ -736,6 +775,7 @@ class BambaDecoderLayer(JambaAttentionDecoderLayer):
cache_params=past_key_value, cache_params=past_key_value,
cache_position=cache_position, cache_position=cache_position,
attention_mask=attention_mask, attention_mask=attention_mask,
**kwargs,
) )
self_attn_weights = None self_attn_weights = None
elif self.layer_type == "attention": elif self.layer_type == "attention":
@@ -838,7 +878,7 @@ class BambaModel(BambaPreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs, # NOOP kwargs, for now **kwargs: Unpack[BambaFlashAttentionKwargs],
) -> BaseModelOutputWithPast: ) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
@@ -890,7 +930,7 @@ class BambaModel(BambaPreTrainedModel):
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__, partial(decoder_layer.__call__, **kwargs),
hidden_states, hidden_states,
layer_mask, layer_mask,
position_ids, position_ids,
@@ -910,6 +950,7 @@ class BambaModel(BambaPreTrainedModel):
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings, position_embeddings=position_embeddings,
**kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]

View File

@@ -439,6 +439,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
cache_params: Optional[HybridMambaAttentionDynamicCache] = None, cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.IntTensor] = None,
): ):
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
@@ -521,7 +522,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
A, A,
D=self.D, D=self.D,
chunk_size=self.chunk_size, chunk_size=self.chunk_size,
seq_idx=None, # was seq_idx seq_idx=seq_idx,
activation=self.activation, activation=self.activation,
rmsnorm_weight=self.norm.weight, rmsnorm_weight=self.norm.weight,
rmsnorm_eps=self.norm.variance_epsilon, rmsnorm_eps=self.norm.variance_epsilon,
@@ -562,6 +563,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
weight=self.conv1d.weight.squeeze(1), weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias, bias=self.conv1d.bias,
activation=self.activation, activation=self.activation,
seq_idx=seq_idx,
).transpose(1, 2) ).transpose(1, 2)
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
@@ -581,7 +583,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
chunk_size=self.chunk_size, chunk_size=self.chunk_size,
D=self.D, D=self.D,
z=None, z=None,
seq_idx=None, seq_idx=seq_idx,
return_final_states=True, return_final_states=True,
dt_bias=self.dt_bias, dt_bias=self.dt_bias,
dt_softplus=True, dt_softplus=True,
@@ -815,9 +817,15 @@ class GraniteMoeHybridMambaLayer(nn.Module):
cache_params: Optional[HybridMambaAttentionDynamicCache] = None, cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.IntTensor] = None,
**kwargs,
): ):
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
if seq_idx is not None:
raise NotImplementedError(
"`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
)
dtype = hidden_states.dtype dtype = hidden_states.dtype
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66

View File

@@ -14,16 +14,25 @@
"""Testing suite for the PyTorch Bamba model.""" """Testing suite for the PyTorch Bamba model."""
import inspect import inspect
import tempfile
import unittest import unittest
import pytest import pytest
from pytest import mark
from transformers import AutoTokenizer, BambaConfig, is_torch_available from transformers import (
AutoTokenizer,
BambaConfig,
DataCollatorWithFlattening,
is_torch_available,
)
from transformers.testing_utils import ( from transformers.testing_utils import (
Expectations, Expectations,
require_deterministic_for_xpu, require_deterministic_for_xpu,
require_flash_attn,
require_torch, require_torch,
require_torch_accelerator, require_torch_accelerator,
require_torch_gpu,
slow, slow,
torch_device, torch_device,
) )
@@ -489,6 +498,92 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# They should result in very similar logits # They should result in very similar logits
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5) torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
@unittest.skip(
"Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training."
)
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip(
"Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training."
)
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
pass
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_idx_and_fa_kwargs(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
self.skipTest("Model dummy inputs should contain padding in their attention mask")
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
if "position_ids" not in inspect.signature(model.forward).parameters:
self.skipTest("Model does not support position_ids")
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# ensure left padding, to adapt for some models
if 0 in inputs_dict["attention_mask"][:, -1]:
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
dummy_attention_mask = inputs_dict["attention_mask"]
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
model = (
model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
)
.to(torch_device)
.eval()
)
# flatten
features = [
{"input_ids": i[a.bool()].tolist()}
for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
]
# add position_ids + fa_kwargs + seq_idx
data_collator = DataCollatorWithFlattening(
return_tensors="pt", return_seq_idx=True, return_flash_attn_kwargs=True
)
batch = data_collator(features)
batch_cuda = {k: t.cuda() if torch.is_tensor(t) else t for k, t in batch.items()}
res_padded = model(**inputs_dict)
res_padfree = model(**batch_cuda)
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
logits_padfree = res_padfree.logits[0]
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
# acceptable numerical instability
tol = torch.finfo(torch.float16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
@slow @slow
@require_torch @require_torch