Add padding-free to bamba (#35861)
* add seq_idx and fa kwargs * update tests * docs and grad ckpt support * fmt * better names * test_raise_missing_padding_free_kwarg_errs * + seq_idx in doc strings * padding free training docs * add link to pr plots * raise err on attn_mask with padding free * rm raising missing padding free err test * BambaFlashAttentionKwargs * run modular util for modular_granitemoehybrid.py
This commit is contained in:
@@ -63,6 +63,34 @@ response = model.generate(**inputs, max_new_tokens=64)
|
||||
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
|
||||
```
|
||||
|
||||
|
||||
## Padding-Free Training
|
||||
|
||||
Bamba supports padding-free training in which distinct training examples can be concatenated
|
||||
together while nevertheless processing the inputs as though they belonged to separate batches. When
|
||||
the examples are of varying lengths, padding-free training can provide significant speedups and
|
||||
memory savings compared to batching the examples together and using padding, as the unnecessary
|
||||
compute and memory due to padding is avoided entirely. The performance gains depend on factors such
|
||||
as the model and the data distribution, but throughput gains up to [~2x are commonly
|
||||
seen](https://github.com/huggingface/transformers/pull/35861#issue-2807873129).
|
||||
|
||||
Using padding-free training with Bamba requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d`
|
||||
packages, and the following arguments must be passed to the model in addition to `input_ids` and
|
||||
`labels`:
|
||||
* `position_ids: torch.LongTensor`: the position index of each token in each sequence.
|
||||
* `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
|
||||
* Each of the [`FlashAttentionKwargs`]:
|
||||
    * `cu_seq_lens_q: torch.LongTensor`: the cumulative sequence lengths of all queries.
|
||||
    * `cu_seq_lens_k: torch.LongTensor`: the cumulative sequence lengths of all keys.
|
||||
    * `max_length_q: int`: the longest query length in the batch.
|
||||
    * `max_length_k: int`: the longest key length in the batch.
|
||||
|
||||
The `attention_mask` inputs should not be provided. The [`DataCollatorWithFlattening`] can be used
|
||||
to programmatically generate the above set of additional arguments using `return_seq_idx=True` and
|
||||
`return_flash_attn_kwargs=True`. See [this blog post](https://huggingface.co/blog/packing-with-FA2)
|
||||
for additional information.
|
||||
|
||||
|
||||
[[autodoc]] BambaForCausalLM
|
||||
- forward
|
||||
|
||||
|
||||
@@ -24,7 +24,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Tuple, TypedDict, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -61,6 +62,31 @@ else:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class BambaFlashAttentionKwargs(TypedDict, total=False):
    """
    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
    Use cases include padding-free training and fewer `torch.compile` graph breaks.

    All keys are optional (`total=False`).

    Attributes:
        cu_seq_lens_q (`torch.LongTensor`):
            Cumulative sequence lengths of all queries.
        cu_seq_lens_k (`torch.LongTensor`):
            Cumulative sequence lengths of all keys.
        max_length_q (`int`):
            Maximum sequence length for query state.
        max_length_k (`int`):
            Maximum sequence length for key state.
        seq_idx (`torch.IntTensor`):
            Index of each packed sequence.
    """

    cu_seq_lens_q: torch.LongTensor
    cu_seq_lens_k: torch.LongTensor
    max_length_q: int
    max_length_k: int
    seq_idx: torch.IntTensor
|
||||
|
||||
|
||||
# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
|
||||
class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
|
||||
"""
|
||||
@@ -487,6 +513,7 @@ class BambaMixer(nn.Module):
|
||||
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
seq_idx: Optional[torch.IntTensor] = None,
|
||||
):
|
||||
# 1. Gated MLP's linear projection
|
||||
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
|
||||
@@ -569,7 +596,7 @@ class BambaMixer(nn.Module):
|
||||
A,
|
||||
D=self.D,
|
||||
chunk_size=self.chunk_size,
|
||||
seq_idx=None, # was seq_idx
|
||||
seq_idx=seq_idx,
|
||||
activation=self.activation,
|
||||
rmsnorm_weight=self.norm.weight,
|
||||
rmsnorm_eps=self.norm.variance_epsilon,
|
||||
@@ -610,6 +637,7 @@ class BambaMixer(nn.Module):
|
||||
weight=self.conv1d.weight.squeeze(1),
|
||||
bias=self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
seq_idx=seq_idx,
|
||||
).transpose(1, 2)
|
||||
|
||||
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
|
||||
@@ -629,7 +657,7 @@ class BambaMixer(nn.Module):
|
||||
chunk_size=self.chunk_size,
|
||||
D=self.D,
|
||||
z=None,
|
||||
seq_idx=None,
|
||||
seq_idx=seq_idx,
|
||||
return_final_states=True,
|
||||
dt_bias=self.dt_bias,
|
||||
dt_softplus=True,
|
||||
@@ -863,9 +891,15 @@ class BambaMixer(nn.Module):
|
||||
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
seq_idx: Optional[torch.IntTensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
|
||||
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
|
||||
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
|
||||
if seq_idx is not None:
|
||||
raise NotImplementedError(
|
||||
"`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
|
||||
)
|
||||
dtype = hidden_states.dtype
|
||||
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
|
||||
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
|
||||
@@ -939,7 +973,7 @@ class BambaDecoderLayer(nn.Module):
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs,
|
||||
**kwargs: Unpack[BambaFlashAttentionKwargs],
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@@ -959,8 +993,8 @@ class BambaDecoderLayer(nn.Module):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
|
||||
padding-free training and/or improve torch.compile performance.
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
@@ -974,6 +1008,7 @@ class BambaDecoderLayer(nn.Module):
|
||||
cache_params=past_key_value,
|
||||
cache_position=cache_position,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
self_attn_weights = None
|
||||
elif self.layer_type == "attention":
|
||||
@@ -1076,7 +1111,7 @@ class BambaModel(BambaPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs, # NOOP kwargs, for now
|
||||
**kwargs: Unpack[BambaFlashAttentionKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@@ -1128,7 +1163,7 @@ class BambaModel(BambaPreTrainedModel):
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
partial(decoder_layer.__call__, **kwargs),
|
||||
hidden_states,
|
||||
layer_mask,
|
||||
position_ids,
|
||||
@@ -1148,6 +1183,7 @@ class BambaModel(BambaPreTrainedModel):
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
@@ -19,7 +19,8 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch Bamba model."""
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple, TypedDict, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -46,7 +47,12 @@ from transformers.models.mamba2.modeling_mamba2 import (
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import auto_docstring, can_return_tuple, logging
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
)
|
||||
from ...utils.import_utils import is_causal_conv1d_available, is_flash_attn_2_available, is_mamba_2_ssm_available
|
||||
from .configuration_bamba import BambaConfig
|
||||
|
||||
@@ -71,6 +77,31 @@ is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_c
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class BambaFlashAttentionKwargs(TypedDict, total=False):
    """
    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
    Use cases include padding-free training and fewer `torch.compile` graph breaks.

    All keys are optional (`total=False`).

    Attributes:
        cu_seq_lens_q (`torch.LongTensor`):
            Cumulative sequence lengths of all queries.
        cu_seq_lens_k (`torch.LongTensor`):
            Cumulative sequence lengths of all keys.
        max_length_q (`int`):
            Maximum sequence length for query state.
        max_length_k (`int`):
            Maximum sequence length for key state.
        seq_idx (`torch.IntTensor`):
            Index of each packed sequence.
    """

    cu_seq_lens_q: torch.LongTensor
    cu_seq_lens_k: torch.LongTensor
    max_length_q: int
    max_length_k: int
    seq_idx: torch.IntTensor
|
||||
|
||||
|
||||
# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
|
||||
class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
|
||||
"""
|
||||
@@ -278,6 +309,7 @@ class BambaMixer(nn.Module):
|
||||
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
seq_idx: Optional[torch.IntTensor] = None,
|
||||
):
|
||||
# 1. Gated MLP's linear projection
|
||||
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
|
||||
@@ -360,7 +392,7 @@ class BambaMixer(nn.Module):
|
||||
A,
|
||||
D=self.D,
|
||||
chunk_size=self.chunk_size,
|
||||
seq_idx=None, # was seq_idx
|
||||
seq_idx=seq_idx,
|
||||
activation=self.activation,
|
||||
rmsnorm_weight=self.norm.weight,
|
||||
rmsnorm_eps=self.norm.variance_epsilon,
|
||||
@@ -401,6 +433,7 @@ class BambaMixer(nn.Module):
|
||||
weight=self.conv1d.weight.squeeze(1),
|
||||
bias=self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
seq_idx=seq_idx,
|
||||
).transpose(1, 2)
|
||||
|
||||
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
|
||||
@@ -420,7 +453,7 @@ class BambaMixer(nn.Module):
|
||||
chunk_size=self.chunk_size,
|
||||
D=self.D,
|
||||
z=None,
|
||||
seq_idx=None,
|
||||
seq_idx=seq_idx,
|
||||
return_final_states=True,
|
||||
dt_bias=self.dt_bias,
|
||||
dt_softplus=True,
|
||||
@@ -654,9 +687,15 @@ class BambaMixer(nn.Module):
|
||||
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
seq_idx: Optional[torch.IntTensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
|
||||
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
|
||||
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
|
||||
if seq_idx is not None:
|
||||
raise NotImplementedError(
|
||||
"`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
|
||||
)
|
||||
dtype = hidden_states.dtype
|
||||
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
|
||||
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
|
||||
@@ -701,7 +740,7 @@ class BambaDecoderLayer(JambaAttentionDecoderLayer):
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs,
|
||||
**kwargs: Unpack[BambaFlashAttentionKwargs],
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@@ -721,8 +760,8 @@ class BambaDecoderLayer(JambaAttentionDecoderLayer):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
|
||||
padding-free training and/or improve torch.compile performance.
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
@@ -736,6 +775,7 @@ class BambaDecoderLayer(JambaAttentionDecoderLayer):
|
||||
cache_params=past_key_value,
|
||||
cache_position=cache_position,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
self_attn_weights = None
|
||||
elif self.layer_type == "attention":
|
||||
@@ -838,7 +878,7 @@ class BambaModel(BambaPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs, # NOOP kwargs, for now
|
||||
**kwargs: Unpack[BambaFlashAttentionKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@@ -890,7 +930,7 @@ class BambaModel(BambaPreTrainedModel):
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
partial(decoder_layer.__call__, **kwargs),
|
||||
hidden_states,
|
||||
layer_mask,
|
||||
position_ids,
|
||||
@@ -910,6 +950,7 @@ class BambaModel(BambaPreTrainedModel):
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
@@ -439,6 +439,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
|
||||
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
seq_idx: Optional[torch.IntTensor] = None,
|
||||
):
|
||||
# 1. Gated MLP's linear projection
|
||||
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
|
||||
@@ -521,7 +522,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
|
||||
A,
|
||||
D=self.D,
|
||||
chunk_size=self.chunk_size,
|
||||
seq_idx=None, # was seq_idx
|
||||
seq_idx=seq_idx,
|
||||
activation=self.activation,
|
||||
rmsnorm_weight=self.norm.weight,
|
||||
rmsnorm_eps=self.norm.variance_epsilon,
|
||||
@@ -562,6 +563,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
|
||||
weight=self.conv1d.weight.squeeze(1),
|
||||
bias=self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
seq_idx=seq_idx,
|
||||
).transpose(1, 2)
|
||||
|
||||
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
|
||||
@@ -581,7 +583,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
|
||||
chunk_size=self.chunk_size,
|
||||
D=self.D,
|
||||
z=None,
|
||||
seq_idx=None,
|
||||
seq_idx=seq_idx,
|
||||
return_final_states=True,
|
||||
dt_bias=self.dt_bias,
|
||||
dt_softplus=True,
|
||||
@@ -815,9 +817,15 @@ class GraniteMoeHybridMambaLayer(nn.Module):
|
||||
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
seq_idx: Optional[torch.IntTensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
|
||||
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
|
||||
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
|
||||
if seq_idx is not None:
|
||||
raise NotImplementedError(
|
||||
"`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
|
||||
)
|
||||
dtype = hidden_states.dtype
|
||||
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
|
||||
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
|
||||
|
||||
@@ -14,16 +14,25 @@
|
||||
"""Testing suite for the PyTorch Bamba model."""
|
||||
|
||||
import inspect
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from pytest import mark
|
||||
|
||||
from transformers import AutoTokenizer, BambaConfig, is_torch_available
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
BambaConfig,
|
||||
DataCollatorWithFlattening,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
require_deterministic_for_xpu,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -489,6 +498,92 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
# They should result in very similar logits
|
||||
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
|
||||
|
||||
@unittest.skip(
|
||||
"Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training."
|
||||
)
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training."
|
||||
)
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_idx_and_fa_kwargs(self):
    """
    Check that a flattened (padding-free) batch yields the same logits as the padded batch.

    The padded batch is run with its `attention_mask`; the padding-free batch is built with
    `DataCollatorWithFlattening(return_seq_idx=True, return_flash_attn_kwargs=True)`. Both are run
    through the same fp16 model loaded with `attn_implementation="flash_attention_2"`, and the
    logits at non-padded positions are compared.
    """
    if not self.has_attentions:
        self.skipTest(reason="Model architecture does not support attentions")

    # Only used below to size `max_position_embeddings`; this test never calls `generate`.
    max_new_tokens = 30

    for model_class in self.all_generative_model_classes:
        if not model_class._supports_flash_attn_2:
            self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        # The comparison is only meaningful when the dummy batch actually contains padding.
        if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
            self.skipTest("Model dummy inputs should contain padding in their attention mask")

        dummy_input = inputs_dict[model_class.main_input_name]
        if dummy_input.dtype in [torch.float32, torch.bfloat16]:
            dummy_input = dummy_input.to(torch.float16)

        # make sure that all models have enough positions for generation
        if hasattr(config, "max_position_embeddings"):
            config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1

        model = model_class(config)
        if "position_ids" not in inspect.signature(model.forward).parameters:
            self.skipTest("Model does not support position_ids")

        # Save and reload so the model can be instantiated in fp16 with flash_attention_2.
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)

            # ensure left padding, to adapt for some models
            if 0 in inputs_dict["attention_mask"][:, -1]:
                inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
            dummy_attention_mask = inputs_dict["attention_mask"]
            # Overwrite masked-out positions with the pad token id.
            inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id

            model = (
                model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    low_cpu_mem_usage=True,
                )
                .to(torch_device)
                .eval()
            )

            # flatten
            # Keep only the unmasked tokens of each row: one variable-length example per row.
            features = [
                {"input_ids": i[a.bool()].tolist()}
                for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
            ]

            # add position_ids + fa_kwargs + seq_idx
            data_collator = DataCollatorWithFlattening(
                return_tensors="pt", return_seq_idx=True, return_flash_attn_kwargs=True
            )
            batch = data_collator(features)
            # Move tensor entries to the GPU; non-tensor entries are passed through unchanged.
            batch_cuda = {k: t.cuda() if torch.is_tensor(t) else t for k, t in batch.items()}

            res_padded = model(**inputs_dict)
            res_padfree = model(**batch_cuda)

            # Select the padded logits at real-token positions only.
            logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
            # NOTE(review): presumably the flattened batch is a single packed row, so row 0 holds
            # every token — confirm against DataCollatorWithFlattening.
            logits_padfree = res_padfree.logits[0]

            # Predicted tokens must match exactly ...
            torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
            # acceptable numerical instability
            # ... while raw logits may differ by up to fp16 machine epsilon.
            tol = torch.finfo(torch.float16).eps
            torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user