Add padding-free to bamba (#35861)
* add seq_idx and fa kwargs * update tests * docs and grad ckpt support * fmt * better names * test_raise_missing_padding_free_kwarg_errs * + seq_idx in doc strings * padding free training docs * add link to pr plots * raise err on attn_mask with padding free * rm raising missing padding free err test * BambaFlashAttentionKwargs * run modular util for modular_granitemoehybrid.py
This commit is contained in:
@@ -39,7 +39,7 @@ Checkout all Bamba-9B model checkpoints [here](https://github.com/foundation-mod
|
|||||||
<!---
|
<!---
|
||||||
## Usage Tips
|
## Usage Tips
|
||||||
|
|
||||||
Tips:
|
Tips:
|
||||||
|
|
||||||
- The architecture is based on Mamba-2 models.
|
- The architecture is based on Mamba-2 models.
|
||||||
|
|
||||||
@@ -63,7 +63,35 @@ response = model.generate(**inputs, max_new_tokens=64)
|
|||||||
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
|
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Padding-Free Training
|
||||||
|
|
||||||
|
Bamba supports padding-free training in which distinct training examples can be concatenated
|
||||||
|
together while nevertheless processing the inputs as though they belonged to separate batches. When
|
||||||
|
the examples are of varying lengths, padding-free training can provide significant speed ups and
|
||||||
|
memory savings compared to batching the examples together and using padding, as the unnecessary
|
||||||
|
compute and memory due to padding is avoided entirely. The performance gains depend on factors such
|
||||||
|
as the model and the data distribution, but throughput gains up to [~2x are commonly
|
||||||
|
seen](https://github.com/huggingface/transformers/pull/35861#issue-2807873129).
|
||||||
|
|
||||||
|
Using padding-free training with Bamba requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d`
|
||||||
|
packages, and the following arguments must be passed to the model in addition to `input_ids` and
|
||||||
|
`labels`:
|
||||||
|
* `position_ids: torch.LongTensor`: the position index of each token in each sequence.
|
||||||
|
* `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
|
||||||
|
* Each of the [`FlashAttentionKwargs`]
|
||||||
|
* `cu_seq_lens_q: torch.LongTensor`: The cumulative sequence lengths of all queries.
|
||||||
|
* `cu_seq_lens_k: torch.LongTensor`: The cumulative sequence lengths of all keys.
|
||||||
|
* `max_length_q: int`: the longest query length in the batch.
|
||||||
|
* `max_length_k: int`: the longest key length in the batch.
|
||||||
|
|
||||||
|
The `attention_mask` inputs should not be provided. The [`DataCollatorWithFlattening`] can be used
|
||||||
|
to programmatically generate the above set of additional arguments using `return_seq_idx=True` and
|
||||||
|
`return_flash_attn_kwargs=True`. See [this blog post](https://huggingface.co/blog/packing-with-FA2)
|
||||||
|
for additional information.
|
||||||
|
|
||||||
|
|
||||||
[[autodoc]] BambaForCausalLM
|
[[autodoc]] BambaForCausalLM
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).
|
This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).
|
||||||
|
|||||||
@@ -24,7 +24,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from functools import partial
|
||||||
|
from typing import Callable, Optional, Tuple, TypedDict, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -61,6 +62,31 @@ else:
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BambaFlashAttentionKwargs(TypedDict, total=False):
|
||||||
|
"""
|
||||||
|
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
|
||||||
|
Use cases include padding-free training and fewer `torch.compile` graph breaks.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
cu_seq_lens_q (`torch.LongTensor`)
|
||||||
|
Gets cumulative sequence length for query state.
|
||||||
|
cu_seq_lens_k (`torch.LongTensor`)
|
||||||
|
Gets cumulative sequence length for key state.
|
||||||
|
max_length_q (`int`):
|
||||||
|
Maximum sequence length for query state.
|
||||||
|
max_length_k (`int`):
|
||||||
|
Maximum sequence length for key state.
|
||||||
|
seq_idx (`torch.IntTensor):
|
||||||
|
Index of each packed sequence.
|
||||||
|
"""
|
||||||
|
|
||||||
|
cu_seq_lens_q: torch.LongTensor
|
||||||
|
cu_seq_lens_k: torch.LongTensor
|
||||||
|
max_length_q: int
|
||||||
|
max_length_k: int
|
||||||
|
seq_idx: torch.IntTensor
|
||||||
|
|
||||||
|
|
||||||
# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
|
# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
|
||||||
class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
|
class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
|
||||||
"""
|
"""
|
||||||
@@ -487,6 +513,7 @@ class BambaMixer(nn.Module):
|
|||||||
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
seq_idx: Optional[torch.IntTensor] = None,
|
||||||
):
|
):
|
||||||
# 1. Gated MLP's linear projection
|
# 1. Gated MLP's linear projection
|
||||||
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
|
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
|
||||||
@@ -569,7 +596,7 @@ class BambaMixer(nn.Module):
|
|||||||
A,
|
A,
|
||||||
D=self.D,
|
D=self.D,
|
||||||
chunk_size=self.chunk_size,
|
chunk_size=self.chunk_size,
|
||||||
seq_idx=None, # was seq_idx
|
seq_idx=seq_idx,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
rmsnorm_weight=self.norm.weight,
|
rmsnorm_weight=self.norm.weight,
|
||||||
rmsnorm_eps=self.norm.variance_epsilon,
|
rmsnorm_eps=self.norm.variance_epsilon,
|
||||||
@@ -610,6 +637,7 @@ class BambaMixer(nn.Module):
|
|||||||
weight=self.conv1d.weight.squeeze(1),
|
weight=self.conv1d.weight.squeeze(1),
|
||||||
bias=self.conv1d.bias,
|
bias=self.conv1d.bias,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
|
seq_idx=seq_idx,
|
||||||
).transpose(1, 2)
|
).transpose(1, 2)
|
||||||
|
|
||||||
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
|
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
|
||||||
@@ -629,7 +657,7 @@ class BambaMixer(nn.Module):
|
|||||||
chunk_size=self.chunk_size,
|
chunk_size=self.chunk_size,
|
||||||
D=self.D,
|
D=self.D,
|
||||||
z=None,
|
z=None,
|
||||||
seq_idx=None,
|
seq_idx=seq_idx,
|
||||||
return_final_states=True,
|
return_final_states=True,
|
||||||
dt_bias=self.dt_bias,
|
dt_bias=self.dt_bias,
|
||||||
dt_softplus=True,
|
dt_softplus=True,
|
||||||
@@ -863,9 +891,15 @@ class BambaMixer(nn.Module):
|
|||||||
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
seq_idx: Optional[torch.IntTensor] = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
|
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
|
||||||
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
|
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
|
||||||
|
if seq_idx is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
|
||||||
|
)
|
||||||
dtype = hidden_states.dtype
|
dtype = hidden_states.dtype
|
||||||
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
|
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
|
||||||
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
|
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
|
||||||
@@ -939,7 +973,7 @@ class BambaDecoderLayer(nn.Module):
|
|||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||||
**kwargs,
|
**kwargs: Unpack[BambaFlashAttentionKwargs],
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -959,8 +993,8 @@ class BambaDecoderLayer(nn.Module):
|
|||||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||||
with `head_dim` being the embedding dimension of each attention head.
|
with `head_dim` being the embedding dimension of each attention head.
|
||||||
kwargs (`dict`, *optional*):
|
kwargs (`dict`, *optional*):
|
||||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
|
||||||
into the model
|
padding-free training and/or improve torch.compile performance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@@ -974,6 +1008,7 @@ class BambaDecoderLayer(nn.Module):
|
|||||||
cache_params=past_key_value,
|
cache_params=past_key_value,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
self_attn_weights = None
|
self_attn_weights = None
|
||||||
elif self.layer_type == "attention":
|
elif self.layer_type == "attention":
|
||||||
@@ -1076,7 +1111,7 @@ class BambaModel(BambaPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs, # NOOP kwargs, for now
|
**kwargs: Unpack[BambaFlashAttentionKwargs],
|
||||||
) -> BaseModelOutputWithPast:
|
) -> BaseModelOutputWithPast:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@@ -1128,7 +1163,7 @@ class BambaModel(BambaPreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
layer_mask,
|
layer_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
@@ -1148,6 +1183,7 @@ class BambaModel(BambaPreTrainedModel):
|
|||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
position_embeddings=position_embeddings,
|
position_embeddings=position_embeddings,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|||||||
@@ -19,7 +19,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch Bamba model."""
|
"""PyTorch Bamba model."""
|
||||||
|
|
||||||
from typing import Optional, Tuple, Union
|
from functools import partial
|
||||||
|
from typing import Optional, Tuple, TypedDict, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
@@ -46,7 +47,12 @@ from transformers.models.mamba2.modeling_mamba2 import (
|
|||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import auto_docstring, can_return_tuple, logging
|
from ...processing_utils import Unpack
|
||||||
|
from ...utils import (
|
||||||
|
auto_docstring,
|
||||||
|
can_return_tuple,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
from ...utils.import_utils import is_causal_conv1d_available, is_flash_attn_2_available, is_mamba_2_ssm_available
|
from ...utils.import_utils import is_causal_conv1d_available, is_flash_attn_2_available, is_mamba_2_ssm_available
|
||||||
from .configuration_bamba import BambaConfig
|
from .configuration_bamba import BambaConfig
|
||||||
|
|
||||||
@@ -71,6 +77,31 @@ is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_c
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BambaFlashAttentionKwargs(TypedDict, total=False):
|
||||||
|
"""
|
||||||
|
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
|
||||||
|
Use cases include padding-free training and fewer `torch.compile` graph breaks.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
cu_seq_lens_q (`torch.LongTensor`)
|
||||||
|
Gets cumulative sequence length for query state.
|
||||||
|
cu_seq_lens_k (`torch.LongTensor`)
|
||||||
|
Gets cumulative sequence length for key state.
|
||||||
|
max_length_q (`int`):
|
||||||
|
Maximum sequence length for query state.
|
||||||
|
max_length_k (`int`):
|
||||||
|
Maximum sequence length for key state.
|
||||||
|
seq_idx (`torch.IntTensor):
|
||||||
|
Index of each packed sequence.
|
||||||
|
"""
|
||||||
|
|
||||||
|
cu_seq_lens_q: torch.LongTensor
|
||||||
|
cu_seq_lens_k: torch.LongTensor
|
||||||
|
max_length_q: int
|
||||||
|
max_length_k: int
|
||||||
|
seq_idx: torch.IntTensor
|
||||||
|
|
||||||
|
|
||||||
# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
|
# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
|
||||||
class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
|
class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
|
||||||
"""
|
"""
|
||||||
@@ -278,6 +309,7 @@ class BambaMixer(nn.Module):
|
|||||||
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
seq_idx: Optional[torch.IntTensor] = None,
|
||||||
):
|
):
|
||||||
# 1. Gated MLP's linear projection
|
# 1. Gated MLP's linear projection
|
||||||
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
|
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
|
||||||
@@ -360,7 +392,7 @@ class BambaMixer(nn.Module):
|
|||||||
A,
|
A,
|
||||||
D=self.D,
|
D=self.D,
|
||||||
chunk_size=self.chunk_size,
|
chunk_size=self.chunk_size,
|
||||||
seq_idx=None, # was seq_idx
|
seq_idx=seq_idx,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
rmsnorm_weight=self.norm.weight,
|
rmsnorm_weight=self.norm.weight,
|
||||||
rmsnorm_eps=self.norm.variance_epsilon,
|
rmsnorm_eps=self.norm.variance_epsilon,
|
||||||
@@ -401,6 +433,7 @@ class BambaMixer(nn.Module):
|
|||||||
weight=self.conv1d.weight.squeeze(1),
|
weight=self.conv1d.weight.squeeze(1),
|
||||||
bias=self.conv1d.bias,
|
bias=self.conv1d.bias,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
|
seq_idx=seq_idx,
|
||||||
).transpose(1, 2)
|
).transpose(1, 2)
|
||||||
|
|
||||||
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
|
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
|
||||||
@@ -420,7 +453,7 @@ class BambaMixer(nn.Module):
|
|||||||
chunk_size=self.chunk_size,
|
chunk_size=self.chunk_size,
|
||||||
D=self.D,
|
D=self.D,
|
||||||
z=None,
|
z=None,
|
||||||
seq_idx=None,
|
seq_idx=seq_idx,
|
||||||
return_final_states=True,
|
return_final_states=True,
|
||||||
dt_bias=self.dt_bias,
|
dt_bias=self.dt_bias,
|
||||||
dt_softplus=True,
|
dt_softplus=True,
|
||||||
@@ -654,9 +687,15 @@ class BambaMixer(nn.Module):
|
|||||||
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
seq_idx: Optional[torch.IntTensor] = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
|
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
|
||||||
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
|
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
|
||||||
|
if seq_idx is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
|
||||||
|
)
|
||||||
dtype = hidden_states.dtype
|
dtype = hidden_states.dtype
|
||||||
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
|
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
|
||||||
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
|
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
|
||||||
@@ -701,7 +740,7 @@ class BambaDecoderLayer(JambaAttentionDecoderLayer):
|
|||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||||
**kwargs,
|
**kwargs: Unpack[BambaFlashAttentionKwargs],
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -721,8 +760,8 @@ class BambaDecoderLayer(JambaAttentionDecoderLayer):
|
|||||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||||
with `head_dim` being the embedding dimension of each attention head.
|
with `head_dim` being the embedding dimension of each attention head.
|
||||||
kwargs (`dict`, *optional*):
|
kwargs (`dict`, *optional*):
|
||||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
|
||||||
into the model
|
padding-free training and/or improve torch.compile performance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@@ -736,6 +775,7 @@ class BambaDecoderLayer(JambaAttentionDecoderLayer):
|
|||||||
cache_params=past_key_value,
|
cache_params=past_key_value,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
self_attn_weights = None
|
self_attn_weights = None
|
||||||
elif self.layer_type == "attention":
|
elif self.layer_type == "attention":
|
||||||
@@ -838,7 +878,7 @@ class BambaModel(BambaPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs, # NOOP kwargs, for now
|
**kwargs: Unpack[BambaFlashAttentionKwargs],
|
||||||
) -> BaseModelOutputWithPast:
|
) -> BaseModelOutputWithPast:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@@ -890,7 +930,7 @@ class BambaModel(BambaPreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
partial(decoder_layer.__call__, **kwargs),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
layer_mask,
|
layer_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
@@ -910,6 +950,7 @@ class BambaModel(BambaPreTrainedModel):
|
|||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
position_embeddings=position_embeddings,
|
position_embeddings=position_embeddings,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|||||||
@@ -439,6 +439,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
|
|||||||
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
seq_idx: Optional[torch.IntTensor] = None,
|
||||||
):
|
):
|
||||||
# 1. Gated MLP's linear projection
|
# 1. Gated MLP's linear projection
|
||||||
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
|
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
|
||||||
@@ -521,7 +522,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
|
|||||||
A,
|
A,
|
||||||
D=self.D,
|
D=self.D,
|
||||||
chunk_size=self.chunk_size,
|
chunk_size=self.chunk_size,
|
||||||
seq_idx=None, # was seq_idx
|
seq_idx=seq_idx,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
rmsnorm_weight=self.norm.weight,
|
rmsnorm_weight=self.norm.weight,
|
||||||
rmsnorm_eps=self.norm.variance_epsilon,
|
rmsnorm_eps=self.norm.variance_epsilon,
|
||||||
@@ -562,6 +563,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
|
|||||||
weight=self.conv1d.weight.squeeze(1),
|
weight=self.conv1d.weight.squeeze(1),
|
||||||
bias=self.conv1d.bias,
|
bias=self.conv1d.bias,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
|
seq_idx=seq_idx,
|
||||||
).transpose(1, 2)
|
).transpose(1, 2)
|
||||||
|
|
||||||
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
|
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
|
||||||
@@ -581,7 +583,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
|
|||||||
chunk_size=self.chunk_size,
|
chunk_size=self.chunk_size,
|
||||||
D=self.D,
|
D=self.D,
|
||||||
z=None,
|
z=None,
|
||||||
seq_idx=None,
|
seq_idx=seq_idx,
|
||||||
return_final_states=True,
|
return_final_states=True,
|
||||||
dt_bias=self.dt_bias,
|
dt_bias=self.dt_bias,
|
||||||
dt_softplus=True,
|
dt_softplus=True,
|
||||||
@@ -815,9 +817,15 @@ class GraniteMoeHybridMambaLayer(nn.Module):
|
|||||||
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
seq_idx: Optional[torch.IntTensor] = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
|
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
|
||||||
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
|
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
|
||||||
|
if seq_idx is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
|
||||||
|
)
|
||||||
dtype = hidden_states.dtype
|
dtype = hidden_states.dtype
|
||||||
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
|
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
|
||||||
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
|
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
|
||||||
|
|||||||
@@ -14,16 +14,25 @@
|
|||||||
"""Testing suite for the PyTorch Bamba model."""
|
"""Testing suite for the PyTorch Bamba model."""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from pytest import mark
|
||||||
|
|
||||||
from transformers import AutoTokenizer, BambaConfig, is_torch_available
|
from transformers import (
|
||||||
|
AutoTokenizer,
|
||||||
|
BambaConfig,
|
||||||
|
DataCollatorWithFlattening,
|
||||||
|
is_torch_available,
|
||||||
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
Expectations,
|
Expectations,
|
||||||
require_deterministic_for_xpu,
|
require_deterministic_for_xpu,
|
||||||
|
require_flash_attn,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_accelerator,
|
require_torch_accelerator,
|
||||||
|
require_torch_gpu,
|
||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
@@ -489,6 +498,92 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
# They should result in very similar logits
|
# They should result in very similar logits
|
||||||
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
|
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
"Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training."
|
||||||
|
)
|
||||||
|
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
"Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training."
|
||||||
|
)
|
||||||
|
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@require_flash_attn
|
||||||
|
@require_torch_gpu
|
||||||
|
@mark.flash_attn_test
|
||||||
|
@slow
|
||||||
|
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_idx_and_fa_kwargs(self):
|
||||||
|
if not self.has_attentions:
|
||||||
|
self.skipTest(reason="Model architecture does not support attentions")
|
||||||
|
|
||||||
|
max_new_tokens = 30
|
||||||
|
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
if not model_class._supports_flash_attn_2:
|
||||||
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
|
||||||
|
self.skipTest("Model dummy inputs should contain padding in their attention mask")
|
||||||
|
|
||||||
|
dummy_input = inputs_dict[model_class.main_input_name]
|
||||||
|
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
|
||||||
|
dummy_input = dummy_input.to(torch.float16)
|
||||||
|
|
||||||
|
# make sure that all models have enough positions for generation
|
||||||
|
if hasattr(config, "max_position_embeddings"):
|
||||||
|
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
|
||||||
|
|
||||||
|
model = model_class(config)
|
||||||
|
if "position_ids" not in inspect.signature(model.forward).parameters:
|
||||||
|
self.skipTest("Model does not support position_ids")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
# ensure left padding, to adapt for some models
|
||||||
|
if 0 in inputs_dict["attention_mask"][:, -1]:
|
||||||
|
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||||
|
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||||
|
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||||
|
|
||||||
|
model = (
|
||||||
|
model_class.from_pretrained(
|
||||||
|
tmpdirname,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
attn_implementation="flash_attention_2",
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
)
|
||||||
|
.to(torch_device)
|
||||||
|
.eval()
|
||||||
|
)
|
||||||
|
|
||||||
|
# flatten
|
||||||
|
features = [
|
||||||
|
{"input_ids": i[a.bool()].tolist()}
|
||||||
|
for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
|
||||||
|
]
|
||||||
|
|
||||||
|
# add position_ids + fa_kwargs + seq_idx
|
||||||
|
data_collator = DataCollatorWithFlattening(
|
||||||
|
return_tensors="pt", return_seq_idx=True, return_flash_attn_kwargs=True
|
||||||
|
)
|
||||||
|
batch = data_collator(features)
|
||||||
|
batch_cuda = {k: t.cuda() if torch.is_tensor(t) else t for k, t in batch.items()}
|
||||||
|
|
||||||
|
res_padded = model(**inputs_dict)
|
||||||
|
res_padfree = model(**batch_cuda)
|
||||||
|
|
||||||
|
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||||
|
logits_padfree = res_padfree.logits[0]
|
||||||
|
|
||||||
|
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
|
||||||
|
# acceptable numerical instability
|
||||||
|
tol = torch.finfo(torch.float16).eps
|
||||||
|
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
Reference in New Issue
Block a user