From c392d47c9b40a9866663ea3bed97904cec27658b Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Fri, 25 Jul 2025 09:44:52 +0200 Subject: [PATCH] [attention] fix test for packed padfree masking (#39582) * fix most tests * skip a few more tests * address comments * fix chameleon tests * forgot to uncomment * qwen has its own tests with images, rename it as well --- src/transformers/models/bark/modeling_bark.py | 88 +++------- .../models/chameleon/modeling_chameleon.py | 154 ++---------------- .../chameleon/test_modeling_chameleon.py | 10 +- tests/models/emu3/test_modeling_emu3.py | 4 +- tests/models/fuyu/test_modeling_fuyu.py | 8 + tests/models/gemma3/test_modeling_gemma3.py | 8 + tests/models/kosmos2/test_modeling_kosmos2.py | 8 + tests/models/llava/test_modeling_llava.py | 3 +- tests/models/minimax/test_modeling_minimax.py | 20 ++- .../paligemma/test_modeling_paligemma.py | 8 + .../paligemma2/test_modeling_paligemma2.py | 8 + .../persimmon/test_modeling_persimmon.py | 8 + .../test_modeling_qwen2_5_omni.py | 2 +- .../qwen2_5_vl/test_modeling_qwen2_5_vl.py | 2 +- .../models/qwen2_vl/test_modeling_qwen2_vl.py | 2 +- tests/models/voxtral/test_modeling_voxtral.py | 1 + tests/test_modeling_common.py | 69 ++++---- 17 files changed, 153 insertions(+), 250 deletions(-) diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 2e199d4fdd..7aa2bb6e20 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -408,69 +408,31 @@ class BarkCausalModel(BarkPreTrainedModel, GenerationMixin): def set_input_embeddings(self, new_embeddings): self.input_embeds_layer = new_embeddings - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, cache_position=None, **kwargs): - # Overwritten -- bark has a model-specific hack - input_embeds = kwargs.get("input_embeds", None) + def prepare_inputs_for_generation( + self, + input_ids, + attention_mask=None, + input_embeds=None, + past_key_values=None, + position_ids=None, + use_cache=None, + cache_position=None, + **kwargs, + ): + # Overwritten -- bark uses `input_embeds` not `inputS_embeds` - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if cache_position[0] != 0: - # Omit tokens covered by past_key_values - seq_len = input_ids.shape[1] - past_length = past_key_values.get_seq_length() - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - # input_embeds have already been used and is not required anymore - input_embeds = None - else: - if input_embeds is not None and kwargs.get("use_cache"): - seq_len = input_embeds.shape[1] - else: - seq_len = input_ids.shape[1] - - # ensure that attention_mask and position_ids shapes are aligned with the weird Bark hack of reducing - # sequence length on the first forward pass - if attention_mask is not None: - attention_mask = attention_mask[:, :seq_len] - if position_ids is not None: - position_ids = position_ids[:, :seq_len] - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - else: - position_ids = None - - if input_embeds is not None and kwargs.get("use_cache"): - return { - "input_ids": None, - "input_embeds": input_embeds, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - "cache_position": cache_position, - } - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - "cache_position": cache_position, - } + model_inputs = super().prepare_inputs_for_generation( + input_ids, + attention_mask=attention_mask, + inputs_embeds=input_embeds, + past_key_values=past_key_values, + position_ids=position_ids, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + model_inputs["input_embeds"] = model_inputs.pop("inputs_embeds", None) + return model_inputs @auto_docstring def forward( @@ -546,7 +508,7 @@ class BarkCausalModel(BarkPreTrainedModel, GenerationMixin): return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_length = past_key_values.get_seq_length() if past_key_values is not None else past_key_values + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if position_ids is None: position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index b70387d595..b003f9e4e8 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -25,7 +25,7 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -35,19 +35,12 @@ from ...utils import ( TransformersKwargs, auto_docstring, can_return_tuple, - is_torch_flex_attn_available, is_torchdynamo_compiling, logging, ) from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -353,15 +346,8 @@ class ChameleonAttention(nn.Module): key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -942,7 +928,7 @@ class ChameleonModel(ChameleonPreTrainedModel): else: special_image_mask = input_ids == self.vocabulary_mapping.image_token_id - n_image_tokens_in_text = (special_image_mask).sum() + n_image_tokens_in_text = special_image_mask.sum() special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) image_embeds = self.get_image_features(pixel_values) @@ -966,8 +952,13 @@ class ChameleonModel(ChameleonPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, ) # embed positions @@ -1015,131 +1006,6 @@ class ChameleonModel(ChameleonPreTrainedModel): attentions=all_self_attns, ) - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - @auto_docstring( custom_intro=""" diff --git a/tests/models/chameleon/test_modeling_chameleon.py b/tests/models/chameleon/test_modeling_chameleon.py index 67baab37c0..35de309661 100644 --- a/tests/models/chameleon/test_modeling_chameleon.py +++ b/tests/models/chameleon/test_modeling_chameleon.py @@ -270,7 +270,7 @@ class ChameleonVision2SeqModelTester(ChameleonModelTester): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_ids[input_ids == self.image_token_id] = self.pad_token_id input_ids[:, : self.image_seq_length] = self.image_token_id - attention_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) + attention_mask = input_ids.ne(self.pad_token_id).to(torch_device) pixel_values = floats_tensor([self.batch_size, 3, self.image_size, self.image_size]) config = self.get_config() @@ -325,6 +325,14 @@ class ChameleonVision2SeqModelTest(ModelTesterMixin, GenerationTesterMixin, unit def test_model_is_small(self): pass + @unittest.skip("Chameleon applies key/query norm which doesn't work with packing") + def test_eager_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip("Chameleon applies key/query norm which doesn't work with packing") + def test_sdpa_padding_matches_padding_free_with_position_ids(self): + pass + def test_mismatching_num_image_tokens(self): """ Tests that VLMs through an error with explicit message saying what is wrong diff --git a/tests/models/emu3/test_modeling_emu3.py b/tests/models/emu3/test_modeling_emu3.py index 978febaa49..6c4f718590 100644 --- a/tests/models/emu3/test_modeling_emu3.py +++ b/tests/models/emu3/test_modeling_emu3.py @@ -89,7 +89,7 @@ class Emu3Text2TextModelTester: def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - attention_mask = input_ids.ne(1).to(torch_device) + attention_mask = input_ids.ne(self.pad_token_id).to(torch_device) config = self.get_config() @@ -234,9 +234,9 @@ class Emu3Vision2TextModelTester: config = self.get_config() input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size) - attention_mask = input_ids.ne(1).to(torch_device) input_ids[input_ids == self.image_token_id] = self.pad_token_id input_ids[:, : self.image_seq_length] = self.image_token_id + attention_mask = input_ids.ne(self.pad_token_id).to(torch_device) pixel_values = floats_tensor( [ diff --git a/tests/models/fuyu/test_modeling_fuyu.py b/tests/models/fuyu/test_modeling_fuyu.py index 1be8d9fcca..6ca7b23af6 100644 --- a/tests/models/fuyu/test_modeling_fuyu.py +++ b/tests/models/fuyu/test_modeling_fuyu.py @@ -214,6 +214,14 @@ class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin def test_generate_continue_from_inputs_embeds(): pass + @unittest.skip("Persimmon backbone applies key/query norm which doesn't work with packing") + def test_eager_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip("Persimmon backbone applies key/query norm which doesn't work with packing") + def test_sdpa_padding_matches_padding_free_with_position_ids(self): + pass + @slow @require_torch_accelerator diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index 3817acfd50..65fcf547b4 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -143,6 +143,14 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase def test_multi_gpu_data_parallel_forward(self): pass + @unittest.skip("Gemma3 applies key/query norm which doesn't work with packing") + def test_eager_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip("Gemma3 applies key/query norm which doesn't work with packing") + def test_sdpa_padding_matches_padding_free_with_position_ids(self): + pass + class Gemma3Vision2TextModelTester: def __init__( diff --git a/tests/models/kosmos2/test_modeling_kosmos2.py b/tests/models/kosmos2/test_modeling_kosmos2.py index e23b8672d2..6bb86406ce 100644 --- a/tests/models/kosmos2/test_modeling_kosmos2.py +++ b/tests/models/kosmos2/test_modeling_kosmos2.py @@ -465,6 +465,14 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ): pass + @unittest.skip("KOSMOS-2 doesn't support padding") + def test_eager_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip("KOSMOS-2 doesn't support padding") + def test_sdpa_padding_matches_padding_free_with_position_ids(self): + pass + @pytest.mark.generate def test_left_padding_compatibility(self): # Overwrite because Kosmos-2 need to padd pixel values and pad image-attn-mask diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index 39eed1c45d..f8d61fb62b 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -152,9 +152,10 @@ class LlavaVisionText2TextModelTester: config_and_inputs = self.prepare_config_and_inputs() config, pixel_values = config_and_inputs input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1 - attention_mask = input_ids.ne(1).to(torch_device) input_ids[input_ids == config.image_token_index] = self.pad_token_id input_ids[:, : self.num_image_tokens] = config.image_token_index + attention_mask = input_ids.ne(1).to(torch_device) + inputs_dict = { "pixel_values": pixel_values, "input_ids": input_ids, diff --git a/tests/models/minimax/test_modeling_minimax.py b/tests/models/minimax/test_modeling_minimax.py index d827ea8620..475f9a2d07 100644 --- a/tests/models/minimax/test_modeling_minimax.py +++ b/tests/models/minimax/test_modeling_minimax.py @@ -214,27 +214,27 @@ class MiniMaxModelTest(CausalLMModelTest, unittest.TestCase): batch_size, seq_length = inputs["input_ids"].shape self._check_past_key_values_for_generate(batch_size, past_kv, seq_length, config) - @unittest.skip(reason="MiniMaxCache doesnot support `crop()` method") + @unittest.skip(reason="MiniMaxCache does not support `crop()` method") def test_prompt_lookup_decoding_matches_greedy_search(self): pass - @unittest.skip(reason="MiniMaxCache doesnot support `crop()` method") + @unittest.skip(reason="MiniMaxCache does not support `crop()` method") def test_contrastive_generate_low_memory(self): pass - @unittest.skip(reason="MiniMaxCache doesnot support `crop()` method") + @unittest.skip(reason="MiniMaxCache does not support `crop()` method") def test_assisted_decoding_sample(self): pass - @unittest.skip(reason="MiniMaxCache doesnot support `crop()` method") + @unittest.skip(reason="MiniMaxCache does not support `crop()` method") def test_assisted_decoding_matches_greedy_search_0_random(self): pass - @unittest.skip(reason="MiniMaxCache doesnot support `crop()` method") + @unittest.skip(reason="MiniMaxCache does not support `crop()` method") def test_assisted_decoding_matches_greedy_search_1_same(self): pass - @unittest.skip(reason="MiniMaxCache doesnot support `crop()` method") + @unittest.skip(reason="MiniMaxCache does not support `crop()` method") def test_contrastive_generate_dict_outputs_use_cache(self): pass @@ -242,6 +242,14 @@ class MiniMaxModelTest(CausalLMModelTest, unittest.TestCase): def test_attention_outputs(self): pass + @unittest.skip("MiniMax is special") + def test_eager_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip("MiniMax is special") + def test_sdpa_padding_matches_padding_free_with_position_ids(self): + pass + @require_torch @require_torch_accelerator diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py index 7bac1dca60..2267e6dc9d 100644 --- a/tests/models/paligemma/test_modeling_paligemma.py +++ b/tests/models/paligemma/test_modeling_paligemma.py @@ -290,6 +290,14 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): pass + @unittest.skip("Paligemma position ids are 1 indexed") + def test_eager_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip("Paloigemma position ids are 1 indexed") + def test_sdpa_padding_matches_padding_free_with_position_ids(self): + pass + def test_attention_mask_with_token_types(self): """Test that attention masking works correctly both with and without token type IDs.""" config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/paligemma2/test_modeling_paligemma2.py b/tests/models/paligemma2/test_modeling_paligemma2.py index ba8d7a6cac..7523f83fd9 100644 --- a/tests/models/paligemma2/test_modeling_paligemma2.py +++ b/tests/models/paligemma2/test_modeling_paligemma2.py @@ -321,3 +321,11 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe @is_flaky def test_generate_compile_model_forward(self): super().test_generate_compile_model_forward() + + @unittest.skip("Paligemma position ids are 1 indexed") + def test_eager_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip("Paligemma position ids are 1 indexed") + def test_sdpa_padding_matches_padding_free_with_position_ids(self): + pass diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py index 2ac23a7e30..e683170e05 100644 --- a/tests/models/persimmon/test_modeling_persimmon.py +++ b/tests/models/persimmon/test_modeling_persimmon.py @@ -76,6 +76,14 @@ class PersimmonModelTest(CausalLMModelTest, unittest.TestCase): test_headmasking = False test_pruning = False + @unittest.skip("Persimmon applies key/query norm which doesn't work with packing") + def test_eager_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip("Persimmon applies key/query norm which doesn't work with packing") + def test_sdpa_padding_matches_padding_free_with_position_ids(self): + pass + @require_torch class PersimmonIntegrationTest(unittest.TestCase): diff --git a/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py b/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py index 1d128e69a5..0863d8def4 100644 --- a/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py +++ b/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py @@ -332,7 +332,7 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: raise ValueError("The eager model should not have SDPA attention layers") - def flash_attention_padding_matches_padding_free_with_position_ids( + def attention_mask_padding_matches_padding_free_with_position_ids( self, attn_implementation: str, fa_kwargs: bool = False ): max_new_tokens = 30 diff --git a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py index f2b229851f..385a714f13 100644 --- a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py +++ b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py @@ -325,7 +325,7 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test ) self.assertIsNotNone(outputs) - def flash_attention_padding_matches_padding_free_with_position_ids( + def attention_mask_padding_matches_padding_free_with_position_ids( self, attn_implementation: str, fa_kwargs: bool = False ): max_new_tokens = 30 diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py index 9f37f61108..ea0a8992b4 100644 --- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py +++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py @@ -283,7 +283,7 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas generation_output.logits[0], forward_output.logits[:, -1, :], rtol=1e-4, atol=1e-4 ) - def flash_attention_padding_matches_padding_free_with_position_ids( + def attention_mask_padding_matches_padding_free_with_position_ids( self, attn_implementation: str, fa_kwargs: bool = False ): max_new_tokens = 30 diff --git a/tests/models/voxtral/test_modeling_voxtral.py b/tests/models/voxtral/test_modeling_voxtral.py index b01d53ce77..5ad7552c01 100644 --- a/tests/models/voxtral/test_modeling_voxtral.py +++ b/tests/models/voxtral/test_modeling_voxtral.py @@ -60,6 +60,7 @@ class VoxtralModelTester: "use_mrope": False, "vocab_size": 99, "head_dim": 8, + "pad_token_id": 0, }, is_training=True, audio_config={ diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index c60cd1eb66..57722d2261 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4125,9 +4125,13 @@ class ModelTesterMixin: assert not loss.isnan().any() - def flash_attention_padding_matches_padding_free_with_position_ids( + def attention_mask_padding_matches_padding_free_with_position_ids( self, attn_implementation: str, fa_kwargs: bool = False ): + """ + Tests that the given attention implementation can work with packed sequences and infers the mask + from position ids. This test requires the model to use new attention mask API which handles packing. + """ if not self.has_attentions: self.skipTest(reason="Model architecture does not support attentions") @@ -4142,17 +4146,27 @@ class ModelTesterMixin: if attn_implementation != "eager" and not getattr(model_class, support_flag[attn_implementation]): self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") + # can't infer if new attn mask API is supported by assume that only model with attention backend support it + if not model_class._supports_attention_backend: + self.skipTest(f"{model_class.__name__} does not support new attention mask API") + + if model_class._is_stateful: # non-transformer models most probably have no packing support + self.skipTest(f"{model_class.__name__} doesn't support packing!") + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if config.is_encoder_decoder: + self.skipTest("Model is an encoder-decoder") + if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict: self.skipTest("Model dummy inputs should contain padding in their attention mask") - dummy_input = inputs_dict[model_class.main_input_name] - if dummy_input.dtype in [torch.float32, torch.float16]: - dummy_input = dummy_input.to(torch.bfloat16) + if "input_ids" not in inputs_dict or inputs_dict["input_ids"].ndim != 2: + self.skipTest("Model dummy inputs should contain text input ids") # make sure that all models have enough positions for generation + dummy_input_ids = inputs_dict["input_ids"] if hasattr(config, "max_position_embeddings"): - config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + config.max_position_embeddings = max_new_tokens + dummy_input_ids.shape[1] + 1 model = model_class(config) if "position_ids" not in inspect.signature(model.forward).parameters: @@ -4164,11 +4178,14 @@ class ModelTesterMixin: with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) - # ensure left padding, to adapt for some models + # Drop all keys except for the minimal set. Hard to manipulate with multimodals/head_mask/etc + inputs_dict = {k: v for k, v in inputs_dict.items() if k in ["input_ids", "attention_mask"]} + + # Ensure left padding, to adapt for some models if 0 in inputs_dict["attention_mask"][:, -1]: inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1) dummy_attention_mask = inputs_dict["attention_mask"] - inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id + dummy_input_ids[~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id model = ( model_class.from_pretrained( @@ -4183,8 +4200,7 @@ class ModelTesterMixin: if fa_kwargs: # flatten features = [ - {"input_ids": i[a.bool()].tolist()} - for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"]) + {"input_ids": i[a.bool()].tolist()} for i, a in zip(dummy_input_ids, dummy_attention_mask) ] # add position_ids + fa_kwargs @@ -4194,55 +4210,48 @@ class ModelTesterMixin: k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items() } else: - # flatten - padfree_inputs_dict = { - k: v[dummy_attention_mask.bool()].unsqueeze(0) - for k, v in inputs_dict.items() - if not k == "attention_mask" - } - # add position_ids - padfree_inputs_dict["position_ids"] = ( + # create packed position_ids + position_ids = ( torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()]) .long() .unsqueeze(0) .to(torch_device) ) + padfree_inputs_dict = { + "input_ids": dummy_input_ids[dummy_attention_mask.bool()].unsqueeze(0), + "position_ids": position_ids, + } - # We need to do simple forward without cache in roder to trigger packed SDPA/FLEX/EAGER path + # We need to do simple forward without cache in order to trigger packed SDPA/flex/eager attention path res_padded = model(**inputs_dict, use_cache=False) res_padfree = model(**padfree_inputs_dict, use_cache=False) - logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()] + logits_padded = res_padded.logits[dummy_attention_mask.bool()] logits_padfree = res_padfree.logits[0] - torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0) # acceptable numerical instability tol = torch.finfo(torch.bfloat16).eps torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol) - # Mark slow for now as it is failing for all multimodals/non-transformer arch models and a few LLMs - # FIXME @raushan - @slow def test_eager_padding_matches_padding_free_with_position_ids(self): - self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="eager") + self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="eager") - @slow def test_sdpa_padding_matches_padding_free_with_position_ids(self): - self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="sdpa") + self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="sdpa") @require_flash_attn @require_torch_gpu @mark.flash_attn_test @slow def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): - self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_2") + self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_2") @require_flash_attn @require_torch_gpu @mark.flash_attn_test @slow def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self): - self.flash_attention_padding_matches_padding_free_with_position_ids( + self.attention_mask_padding_matches_padding_free_with_position_ids( attn_implementation="flash_attention_2", fa_kwargs=True ) @@ -4251,14 +4260,14 @@ class ModelTesterMixin: @mark.flash_attn_3_test @slow def test_flash_attention_3_padding_matches_padding_free_with_position_ids(self): - self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_3") + self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_3") @require_flash_attn_3 @require_torch_gpu @mark.flash_attn_3_test @slow def test_flash_attention_3_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self): - self.flash_attention_padding_matches_padding_free_with_position_ids( + self.attention_mask_padding_matches_padding_free_with_position_ids( attn_implementation="flash_attention_3", fa_kwargs=True )