[attention] fix test for packed padfree masking (#39582)

* fix most tests

* skip a few more tests

* address comments

* fix chameleon tests

* forgot to uncomment

* qwen has its own tests with images, rename it as well
This commit is contained in:
Raushan Turganbay
2025-07-25 09:44:52 +02:00
committed by GitHub
parent 565c035a2e
commit c392d47c9b
17 changed files with 153 additions and 250 deletions

View File

@@ -270,7 +270,7 @@ class ChameleonVision2SeqModelTester(ChameleonModelTester):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids[input_ids == self.image_token_id] = self.pad_token_id
input_ids[:, : self.image_seq_length] = self.image_token_id
attention_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
pixel_values = floats_tensor([self.batch_size, 3, self.image_size, self.image_size])
config = self.get_config()
@@ -325,6 +325,14 @@ class ChameleonVision2SeqModelTest(ModelTesterMixin, GenerationTesterMixin, unit
def test_model_is_small(self):
pass
@unittest.skip("Chameleon applies key/query norm which doesn't work with packing")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Chameleon applies key/query norm which doesn't work with packing")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
def test_mismatching_num_image_tokens(self):
"""
Tests that VLMs through an error with explicit message saying what is wrong

View File

@@ -89,7 +89,7 @@ class Emu3Text2TextModelTester:
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = input_ids.ne(1).to(torch_device)
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
config = self.get_config()
@@ -234,9 +234,9 @@ class Emu3Vision2TextModelTester:
config = self.get_config()
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size)
attention_mask = input_ids.ne(1).to(torch_device)
input_ids[input_ids == self.image_token_id] = self.pad_token_id
input_ids[:, : self.image_seq_length] = self.image_token_id
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
pixel_values = floats_tensor(
[

View File

@@ -214,6 +214,14 @@ class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
def test_generate_continue_from_inputs_embeds():
pass
@unittest.skip("Persimmon backbone applies key/query norm which doesn't work with packing")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Persimmon backbone applies key/query norm which doesn't work with packing")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
@slow
@require_torch_accelerator

View File

@@ -143,6 +143,14 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip("Gemma3 applies key/query norm which doesn't work with packing")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Gemma3 applies key/query norm which doesn't work with packing")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
class Gemma3Vision2TextModelTester:
def __init__(

View File

@@ -465,6 +465,14 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
):
pass
@unittest.skip("KOSMOS-2 doesn't support padding")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("KOSMOS-2 doesn't support padding")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
@pytest.mark.generate
def test_left_padding_compatibility(self):
# Overwrite because Kosmos-2 need to padd pixel values and pad image-attn-mask

View File

@@ -152,9 +152,10 @@ class LlavaVisionText2TextModelTester:
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
attention_mask = input_ids.ne(1).to(torch_device)
input_ids[input_ids == config.image_token_index] = self.pad_token_id
input_ids[:, : self.num_image_tokens] = config.image_token_index
attention_mask = input_ids.ne(1).to(torch_device)
inputs_dict = {
"pixel_values": pixel_values,
"input_ids": input_ids,

View File

@@ -214,27 +214,27 @@ class MiniMaxModelTest(CausalLMModelTest, unittest.TestCase):
batch_size, seq_length = inputs["input_ids"].shape
self._check_past_key_values_for_generate(batch_size, past_kv, seq_length, config)
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
def test_assisted_decoding_sample(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
def test_assisted_decoding_matches_greedy_search_0_random(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
def test_assisted_decoding_matches_greedy_search_1_same(self):
pass
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@@ -242,6 +242,14 @@ class MiniMaxModelTest(CausalLMModelTest, unittest.TestCase):
def test_attention_outputs(self):
pass
@unittest.skip("MiniMax is special")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("MiniMax is special")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
@require_torch
@require_torch_accelerator

View File

@@ -290,6 +290,14 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Paligemma position ids are 1 indexed")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Paloigemma position ids are 1 indexed")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
def test_attention_mask_with_token_types(self):
"""Test that attention masking works correctly both with and without token type IDs."""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -321,3 +321,11 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
@is_flaky
def test_generate_compile_model_forward(self):
super().test_generate_compile_model_forward()
@unittest.skip("Paligemma position ids are 1 indexed")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Paligemma position ids are 1 indexed")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass

View File

@@ -76,6 +76,14 @@ class PersimmonModelTest(CausalLMModelTest, unittest.TestCase):
test_headmasking = False
test_pruning = False
@unittest.skip("Persimmon applies key/query norm which doesn't work with packing")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Persimmon applies key/query norm which doesn't work with packing")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
@require_torch
class PersimmonIntegrationTest(unittest.TestCase):

View File

@@ -332,7 +332,7 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
def flash_attention_padding_matches_padding_free_with_position_ids(
def attention_mask_padding_matches_padding_free_with_position_ids(
self, attn_implementation: str, fa_kwargs: bool = False
):
max_new_tokens = 30

View File

@@ -325,7 +325,7 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
)
self.assertIsNotNone(outputs)
def flash_attention_padding_matches_padding_free_with_position_ids(
def attention_mask_padding_matches_padding_free_with_position_ids(
self, attn_implementation: str, fa_kwargs: bool = False
):
max_new_tokens = 30

View File

@@ -283,7 +283,7 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
generation_output.logits[0], forward_output.logits[:, -1, :], rtol=1e-4, atol=1e-4
)
def flash_attention_padding_matches_padding_free_with_position_ids(
def attention_mask_padding_matches_padding_free_with_position_ids(
self, attn_implementation: str, fa_kwargs: bool = False
):
max_new_tokens = 30

View File

@@ -60,6 +60,7 @@ class VoxtralModelTester:
"use_mrope": False,
"vocab_size": 99,
"head_dim": 8,
"pad_token_id": 0,
},
is_training=True,
audio_config={