[attention] fix test for packed padfree masking (#39582)

* fix most tests

* skip a few more tests

* address comments

* fix chameleon tests

* forgot to uncomment

* qwen has its own tests with images, rename it as well
This commit is contained in:
Raushan Turganbay
2025-07-25 09:44:52 +02:00
committed by GitHub
parent 565c035a2e
commit c392d47c9b
17 changed files with 153 additions and 250 deletions

View File

@@ -408,69 +408,31 @@ class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
def set_input_embeddings(self, new_embeddings):
self.input_embeds_layer = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, cache_position=None, **kwargs):
# Overwritten -- bark has a model-specific hack
input_embeds = kwargs.get("input_embeds", None)
def prepare_inputs_for_generation(
self,
input_ids,
attention_mask=None,
input_embeds=None,
past_key_values=None,
position_ids=None,
use_cache=None,
cache_position=None,
**kwargs,
):
# Overwritten -- bark uses `input_embeds` not `inputS_embeds`
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if cache_position[0] != 0:
# Omit tokens covered by past_key_values
seq_len = input_ids.shape[1]
past_length = past_key_values.get_seq_length()
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
# input_embeds have already been used and is not required anymore
input_embeds = None
else:
if input_embeds is not None and kwargs.get("use_cache"):
seq_len = input_embeds.shape[1]
else:
seq_len = input_ids.shape[1]
# ensure that attention_mask and position_ids shapes are aligned with the weird Bark hack of reducing
# sequence length on the first forward pass
if attention_mask is not None:
attention_mask = attention_mask[:, :seq_len]
if position_ids is not None:
position_ids = position_ids[:, :seq_len]
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
else:
position_ids = None
if input_embeds is not None and kwargs.get("use_cache"):
return {
"input_ids": None,
"input_embeds": input_embeds,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"cache_position": cache_position,
}
return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"cache_position": cache_position,
}
model_inputs = super().prepare_inputs_for_generation(
input_ids,
attention_mask=attention_mask,
inputs_embeds=input_embeds,
past_key_values=past_key_values,
position_ids=position_ids,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
model_inputs["input_embeds"] = model_inputs.pop("inputs_embeds", None)
return model_inputs
@auto_docstring
def forward(
@@ -546,7 +508,7 @@ class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_length = past_key_values.get_seq_length() if past_key_values is not None else past_key_values
past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if position_ids is None:
position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)

View File

@@ -25,7 +25,7 @@ from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -35,19 +35,12 @@ from ...utils import (
TransformersKwargs,
auto_docstring,
can_return_tuple,
is_torch_flex_attn_available,
is_torchdynamo_compiling,
logging,
)
from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -353,15 +346,8 @@ class ChameleonAttention(nn.Module):
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
@@ -942,7 +928,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
else:
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
n_image_tokens_in_text = (special_image_mask).sum()
n_image_tokens_in_text = special_image_mask.sum()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_embeds = self.get_image_features(pixel_values)
@@ -966,8 +952,13 @@ class ChameleonModel(ChameleonPreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
# embed positions
@@ -1015,131 +1006,6 @@ class ChameleonModel(ChameleonPreTrainedModel):
attentions=all_self_attns,
)
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
@auto_docstring(
custom_intro="""

View File

@@ -270,7 +270,7 @@ class ChameleonVision2SeqModelTester(ChameleonModelTester):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids[input_ids == self.image_token_id] = self.pad_token_id
input_ids[:, : self.image_seq_length] = self.image_token_id
attention_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
pixel_values = floats_tensor([self.batch_size, 3, self.image_size, self.image_size])
config = self.get_config()
@@ -325,6 +325,14 @@ class ChameleonVision2SeqModelTest(ModelTesterMixin, GenerationTesterMixin, unit
def test_model_is_small(self):
pass
@unittest.skip("Chameleon applies key/query norm which doesn't work with packing")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Chameleon applies key/query norm which doesn't work with packing")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
def test_mismatching_num_image_tokens(self):
"""
Tests that VLMs throw an error with an explicit message saying what is wrong

View File

@@ -89,7 +89,7 @@ class Emu3Text2TextModelTester:
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = input_ids.ne(1).to(torch_device)
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
config = self.get_config()
@@ -234,9 +234,9 @@ class Emu3Vision2TextModelTester:
config = self.get_config()
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size)
attention_mask = input_ids.ne(1).to(torch_device)
input_ids[input_ids == self.image_token_id] = self.pad_token_id
input_ids[:, : self.image_seq_length] = self.image_token_id
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
pixel_values = floats_tensor(
[

View File

@@ -214,6 +214,14 @@ class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
def test_generate_continue_from_inputs_embeds():
pass
@unittest.skip("Persimmon backbone applies key/query norm which doesn't work with packing")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Persimmon backbone applies key/query norm which doesn't work with packing")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
@slow
@require_torch_accelerator

View File

@@ -143,6 +143,14 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip("Gemma3 applies key/query norm which doesn't work with packing")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Gemma3 applies key/query norm which doesn't work with packing")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
class Gemma3Vision2TextModelTester:
def __init__(

View File

@@ -465,6 +465,14 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
):
pass
@unittest.skip("KOSMOS-2 doesn't support padding")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("KOSMOS-2 doesn't support padding")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
@pytest.mark.generate
def test_left_padding_compatibility(self):
# Overwrite because Kosmos-2 needs to pad pixel values and the image-attn-mask

View File

@@ -152,9 +152,10 @@ class LlavaVisionText2TextModelTester:
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
attention_mask = input_ids.ne(1).to(torch_device)
input_ids[input_ids == config.image_token_index] = self.pad_token_id
input_ids[:, : self.num_image_tokens] = config.image_token_index
attention_mask = input_ids.ne(1).to(torch_device)
inputs_dict = {
"pixel_values": pixel_values,
"input_ids": input_ids,

View File

@@ -214,27 +214,27 @@ class MiniMaxModelTest(CausalLMModelTest, unittest.TestCase):
batch_size, seq_length = inputs["input_ids"].shape
self._check_past_key_values_for_generate(batch_size, past_kv, seq_length, config)
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
def test_assisted_decoding_sample(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
def test_assisted_decoding_matches_greedy_search_0_random(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
def test_assisted_decoding_matches_greedy_search_1_same(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@@ -242,6 +242,14 @@ class MiniMaxModelTest(CausalLMModelTest, unittest.TestCase):
def test_attention_outputs(self):
pass
@unittest.skip("MiniMax is special")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("MiniMax is special")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
@require_torch
@require_torch_accelerator

View File

@@ -290,6 +290,14 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Paligemma position ids are 1 indexed")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Paloigemma position ids are 1 indexed")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
def test_attention_mask_with_token_types(self):
"""Test that attention masking works correctly both with and without token type IDs."""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -321,3 +321,11 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
@is_flaky
def test_generate_compile_model_forward(self):
super().test_generate_compile_model_forward()
@unittest.skip("Paligemma position ids are 1 indexed")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Paligemma position ids are 1 indexed")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass

View File

@@ -76,6 +76,14 @@ class PersimmonModelTest(CausalLMModelTest, unittest.TestCase):
test_headmasking = False
test_pruning = False
@unittest.skip("Persimmon applies key/query norm which doesn't work with packing")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Persimmon applies key/query norm which doesn't work with packing")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
@require_torch
class PersimmonIntegrationTest(unittest.TestCase):

View File

@@ -332,7 +332,7 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
def flash_attention_padding_matches_padding_free_with_position_ids(
def attention_mask_padding_matches_padding_free_with_position_ids(
self, attn_implementation: str, fa_kwargs: bool = False
):
max_new_tokens = 30

View File

@@ -325,7 +325,7 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
)
self.assertIsNotNone(outputs)
def flash_attention_padding_matches_padding_free_with_position_ids(
def attention_mask_padding_matches_padding_free_with_position_ids(
self, attn_implementation: str, fa_kwargs: bool = False
):
max_new_tokens = 30

View File

@@ -283,7 +283,7 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
generation_output.logits[0], forward_output.logits[:, -1, :], rtol=1e-4, atol=1e-4
)
def flash_attention_padding_matches_padding_free_with_position_ids(
def attention_mask_padding_matches_padding_free_with_position_ids(
self, attn_implementation: str, fa_kwargs: bool = False
):
max_new_tokens = 30

View File

@@ -60,6 +60,7 @@ class VoxtralModelTester:
"use_mrope": False,
"vocab_size": 99,
"head_dim": 8,
"pad_token_id": 0,
},
is_training=True,
audio_config={

View File

@@ -4125,9 +4125,13 @@ class ModelTesterMixin:
assert not loss.isnan().any()
def flash_attention_padding_matches_padding_free_with_position_ids(
def attention_mask_padding_matches_padding_free_with_position_ids(
self, attn_implementation: str, fa_kwargs: bool = False
):
"""
Tests that the given attention implementation can work with packed sequences and infers the mask
from position ids. This test requires the model to use new attention mask API which handles packing.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
@@ -4142,17 +4146,27 @@ class ModelTesterMixin:
if attn_implementation != "eager" and not getattr(model_class, support_flag[attn_implementation]):
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
# can't infer whether the new attn mask API is supported, so assume that only models with attention backend support use it
if not model_class._supports_attention_backend:
self.skipTest(f"{model_class.__name__} does not support new attention mask API")
if model_class._is_stateful: # non-transformer models most probably have no packing support
self.skipTest(f"{model_class.__name__} doesn't support packing!")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.is_encoder_decoder:
self.skipTest("Model is an encoder-decoder")
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
self.skipTest("Model dummy inputs should contain padding in their attention mask")
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
if "input_ids" not in inputs_dict or inputs_dict["input_ids"].ndim != 2:
self.skipTest("Model dummy inputs should contain text input ids")
# make sure that all models have enough positions for generation
dummy_input_ids = inputs_dict["input_ids"]
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
config.max_position_embeddings = max_new_tokens + dummy_input_ids.shape[1] + 1
model = model_class(config)
if "position_ids" not in inspect.signature(model.forward).parameters:
@@ -4164,11 +4178,14 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# ensure left padding, to adapt for some models
# Drop all keys except for the minimal set. Hard to manipulate with multimodals/head_mask/etc
inputs_dict = {k: v for k, v in inputs_dict.items() if k in ["input_ids", "attention_mask"]}
# Ensure left padding, to adapt for some models
if 0 in inputs_dict["attention_mask"][:, -1]:
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
dummy_attention_mask = inputs_dict["attention_mask"]
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
dummy_input_ids[~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
model = (
model_class.from_pretrained(
@@ -4183,8 +4200,7 @@ class ModelTesterMixin:
if fa_kwargs:
# flatten
features = [
{"input_ids": i[a.bool()].tolist()}
for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
{"input_ids": i[a.bool()].tolist()} for i, a in zip(dummy_input_ids, dummy_attention_mask)
]
# add position_ids + fa_kwargs
@@ -4194,55 +4210,48 @@ class ModelTesterMixin:
k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()
}
else:
# flatten
padfree_inputs_dict = {
k: v[dummy_attention_mask.bool()].unsqueeze(0)
for k, v in inputs_dict.items()
if not k == "attention_mask"
}
# add position_ids
padfree_inputs_dict["position_ids"] = (
# create packed position_ids
position_ids = (
torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])
.long()
.unsqueeze(0)
.to(torch_device)
)
padfree_inputs_dict = {
"input_ids": dummy_input_ids[dummy_attention_mask.bool()].unsqueeze(0),
"position_ids": position_ids,
}
# We need to do simple forward without cache in roder to trigger packed SDPA/FLEX/EAGER path
# We need to do simple forward without cache in order to trigger packed SDPA/flex/eager attention path
res_padded = model(**inputs_dict, use_cache=False)
res_padfree = model(**padfree_inputs_dict, use_cache=False)
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
logits_padded = res_padded.logits[dummy_attention_mask.bool()]
logits_padfree = res_padfree.logits[0]
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
# acceptable numerical instability
tol = torch.finfo(torch.bfloat16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
# Mark slow for now as it is failing for all multimodals/non-transformer arch models and a few LLMs
# FIXME @raushan
@slow
def test_eager_padding_matches_padding_free_with_position_ids(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="eager")
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="eager")
@slow
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="sdpa")
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="sdpa")
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_2")
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_2")
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(
self.attention_mask_padding_matches_padding_free_with_position_ids(
attn_implementation="flash_attention_2", fa_kwargs=True
)
@@ -4251,14 +4260,14 @@ class ModelTesterMixin:
@mark.flash_attn_3_test
@slow
def test_flash_attention_3_padding_matches_padding_free_with_position_ids(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_3")
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_3")
@require_flash_attn_3
@require_torch_gpu
@mark.flash_attn_3_test
@slow
def test_flash_attention_3_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(
self.attention_mask_padding_matches_padding_free_with_position_ids(
attn_implementation="flash_attention_3", fa_kwargs=True
)