[attention] fix test for packed padfree masking (#39582)

* fix most tests

* skip a few more tests

* address comments

* fix chameleon tests

* forgot to uncomment

* qwen has its own tests with images, rename it as well
This commit is contained in:
Raushan Turganbay
2025-07-25 09:44:52 +02:00
committed by GitHub
parent 565c035a2e
commit c392d47c9b
17 changed files with 153 additions and 250 deletions

View File

@@ -270,7 +270,7 @@ class ChameleonVision2SeqModelTester(ChameleonModelTester):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids[input_ids == self.image_token_id] = self.pad_token_id
input_ids[:, : self.image_seq_length] = self.image_token_id
attention_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
pixel_values = floats_tensor([self.batch_size, 3, self.image_size, self.image_size])
config = self.get_config()
@@ -325,6 +325,14 @@ class ChameleonVision2SeqModelTest(ModelTesterMixin, GenerationTesterMixin, unit
def test_model_is_small(self):
pass
@unittest.skip("Chameleon applies key/query norm which doesn't work with packing")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Chameleon applies key/query norm which doesn't work with packing")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
def test_mismatching_num_image_tokens(self):
"""
Tests that VLMs through an error with explicit message saying what is wrong

View File

@@ -89,7 +89,7 @@ class Emu3Text2TextModelTester:
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = input_ids.ne(1).to(torch_device)
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
config = self.get_config()
@@ -234,9 +234,9 @@ class Emu3Vision2TextModelTester:
config = self.get_config()
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size)
attention_mask = input_ids.ne(1).to(torch_device)
input_ids[input_ids == self.image_token_id] = self.pad_token_id
input_ids[:, : self.image_seq_length] = self.image_token_id
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
pixel_values = floats_tensor(
[

View File

@@ -214,6 +214,14 @@ class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
def test_generate_continue_from_inputs_embeds():
pass
@unittest.skip("Persimmon backbone applies key/query norm which doesn't work with packing")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Persimmon backbone applies key/query norm which doesn't work with packing")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
@slow
@require_torch_accelerator

View File

@@ -143,6 +143,14 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip("Gemma3 applies key/query norm which doesn't work with packing")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Gemma3 applies key/query norm which doesn't work with packing")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
class Gemma3Vision2TextModelTester:
def __init__(

View File

@@ -465,6 +465,14 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
):
pass
@unittest.skip("KOSMOS-2 doesn't support padding")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("KOSMOS-2 doesn't support padding")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
@pytest.mark.generate
def test_left_padding_compatibility(self):
# Overwrite because Kosmos-2 need to padd pixel values and pad image-attn-mask

View File

@@ -152,9 +152,10 @@ class LlavaVisionText2TextModelTester:
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
attention_mask = input_ids.ne(1).to(torch_device)
input_ids[input_ids == config.image_token_index] = self.pad_token_id
input_ids[:, : self.num_image_tokens] = config.image_token_index
attention_mask = input_ids.ne(1).to(torch_device)
inputs_dict = {
"pixel_values": pixel_values,
"input_ids": input_ids,

View File

@@ -214,27 +214,27 @@ class MiniMaxModelTest(CausalLMModelTest, unittest.TestCase):
batch_size, seq_length = inputs["input_ids"].shape
self._check_past_key_values_for_generate(batch_size, past_kv, seq_length, config)
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
def test_assisted_decoding_sample(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
def test_assisted_decoding_matches_greedy_search_0_random(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
def test_assisted_decoding_matches_greedy_search_1_same(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@@ -242,6 +242,14 @@ class MiniMaxModelTest(CausalLMModelTest, unittest.TestCase):
def test_attention_outputs(self):
pass
@unittest.skip("MiniMax is special")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("MiniMax is special")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
@require_torch
@require_torch_accelerator

View File

@@ -290,6 +290,14 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Paligemma position ids are 1 indexed")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Paloigemma position ids are 1 indexed")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
def test_attention_mask_with_token_types(self):
"""Test that attention masking works correctly both with and without token type IDs."""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -321,3 +321,11 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
@is_flaky
def test_generate_compile_model_forward(self):
super().test_generate_compile_model_forward()
@unittest.skip("Paligemma position ids are 1 indexed")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Paligemma position ids are 1 indexed")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass

View File

@@ -76,6 +76,14 @@ class PersimmonModelTest(CausalLMModelTest, unittest.TestCase):
test_headmasking = False
test_pruning = False
@unittest.skip("Persimmon applies key/query norm which doesn't work with packing")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Persimmon applies key/query norm which doesn't work with packing")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
@require_torch
class PersimmonIntegrationTest(unittest.TestCase):

View File

@@ -332,7 +332,7 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
def flash_attention_padding_matches_padding_free_with_position_ids(
def attention_mask_padding_matches_padding_free_with_position_ids(
self, attn_implementation: str, fa_kwargs: bool = False
):
max_new_tokens = 30

View File

@@ -325,7 +325,7 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
)
self.assertIsNotNone(outputs)
def flash_attention_padding_matches_padding_free_with_position_ids(
def attention_mask_padding_matches_padding_free_with_position_ids(
self, attn_implementation: str, fa_kwargs: bool = False
):
max_new_tokens = 30

View File

@@ -283,7 +283,7 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
generation_output.logits[0], forward_output.logits[:, -1, :], rtol=1e-4, atol=1e-4
)
def flash_attention_padding_matches_padding_free_with_position_ids(
def attention_mask_padding_matches_padding_free_with_position_ids(
self, attn_implementation: str, fa_kwargs: bool = False
):
max_new_tokens = 30

View File

@@ -60,6 +60,7 @@ class VoxtralModelTester:
"use_mrope": False,
"vocab_size": 99,
"head_dim": 8,
"pad_token_id": 0,
},
is_training=True,
audio_config={

View File

@@ -4125,9 +4125,13 @@ class ModelTesterMixin:
assert not loss.isnan().any()
def flash_attention_padding_matches_padding_free_with_position_ids(
def attention_mask_padding_matches_padding_free_with_position_ids(
self, attn_implementation: str, fa_kwargs: bool = False
):
"""
Tests that the given attention implementation can work with packed sequences and infers the mask
from position ids. This test requires the model to use new attention mask API which handles packing.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
@@ -4142,17 +4146,27 @@ class ModelTesterMixin:
if attn_implementation != "eager" and not getattr(model_class, support_flag[attn_implementation]):
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
# can't infer if new attn mask API is supported by assume that only model with attention backend support it
if not model_class._supports_attention_backend:
self.skipTest(f"{model_class.__name__} does not support new attention mask API")
if model_class._is_stateful: # non-transformer models most probably have no packing support
self.skipTest(f"{model_class.__name__} doesn't support packing!")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.is_encoder_decoder:
self.skipTest("Model is an encoder-decoder")
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
self.skipTest("Model dummy inputs should contain padding in their attention mask")
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
if "input_ids" not in inputs_dict or inputs_dict["input_ids"].ndim != 2:
self.skipTest("Model dummy inputs should contain text input ids")
# make sure that all models have enough positions for generation
dummy_input_ids = inputs_dict["input_ids"]
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
config.max_position_embeddings = max_new_tokens + dummy_input_ids.shape[1] + 1
model = model_class(config)
if "position_ids" not in inspect.signature(model.forward).parameters:
@@ -4164,11 +4178,14 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# ensure left padding, to adapt for some models
# Drop all keys except for the minimal set. Hard to manipulate with multimodals/head_mask/etc
inputs_dict = {k: v for k, v in inputs_dict.items() if k in ["input_ids", "attention_mask"]}
# Ensure left padding, to adapt for some models
if 0 in inputs_dict["attention_mask"][:, -1]:
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
dummy_attention_mask = inputs_dict["attention_mask"]
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
dummy_input_ids[~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
model = (
model_class.from_pretrained(
@@ -4183,8 +4200,7 @@ class ModelTesterMixin:
if fa_kwargs:
# flatten
features = [
{"input_ids": i[a.bool()].tolist()}
for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
{"input_ids": i[a.bool()].tolist()} for i, a in zip(dummy_input_ids, dummy_attention_mask)
]
# add position_ids + fa_kwargs
@@ -4194,55 +4210,48 @@ class ModelTesterMixin:
k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()
}
else:
# flatten
padfree_inputs_dict = {
k: v[dummy_attention_mask.bool()].unsqueeze(0)
for k, v in inputs_dict.items()
if not k == "attention_mask"
}
# add position_ids
padfree_inputs_dict["position_ids"] = (
# create packed position_ids
position_ids = (
torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])
.long()
.unsqueeze(0)
.to(torch_device)
)
padfree_inputs_dict = {
"input_ids": dummy_input_ids[dummy_attention_mask.bool()].unsqueeze(0),
"position_ids": position_ids,
}
# We need to do simple forward without cache in roder to trigger packed SDPA/FLEX/EAGER path
# We need to do simple forward without cache in order to trigger packed SDPA/flex/eager attention path
res_padded = model(**inputs_dict, use_cache=False)
res_padfree = model(**padfree_inputs_dict, use_cache=False)
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
logits_padded = res_padded.logits[dummy_attention_mask.bool()]
logits_padfree = res_padfree.logits[0]
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
# acceptable numerical instability
tol = torch.finfo(torch.bfloat16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
# Mark slow for now as it is failing for all multimodals/non-transformer arch models and a few LLMs
# FIXME @raushan
@slow
def test_eager_padding_matches_padding_free_with_position_ids(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="eager")
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="eager")
@slow
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="sdpa")
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="sdpa")
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_2")
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_2")
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(
self.attention_mask_padding_matches_padding_free_with_position_ids(
attn_implementation="flash_attention_2", fa_kwargs=True
)
@@ -4251,14 +4260,14 @@ class ModelTesterMixin:
@mark.flash_attn_3_test
@slow
def test_flash_attention_3_padding_matches_padding_free_with_position_ids(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_3")
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_3")
@require_flash_attn_3
@require_torch_gpu
@mark.flash_attn_3_test
@slow
def test_flash_attention_3_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(
self.attention_mask_padding_matches_padding_free_with_position_ids(
attn_implementation="flash_attention_3", fa_kwargs=True
)