[attention] fix test for packed padfree masking (#39582)
* fix most tests * skip a few more tests * address comments * fix chameleon tests * forgot to uncomment * qwen has its own tests with images, rename it as well
This commit is contained in:
committed by
GitHub
parent
565c035a2e
commit
c392d47c9b
@@ -408,69 +408,31 @@ class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.input_embeds_layer = new_embeddings
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, cache_position=None, **kwargs):
|
||||
# Overwritten -- bark has a model-specific hack
|
||||
input_embeds = kwargs.get("input_embeds", None)
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
input_embeds=None,
|
||||
past_key_values=None,
|
||||
position_ids=None,
|
||||
use_cache=None,
|
||||
cache_position=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten -- bark uses `input_embeds` not `inputS_embeds`
|
||||
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
|
||||
if cache_position[0] != 0:
|
||||
# Omit tokens covered by past_key_values
|
||||
seq_len = input_ids.shape[1]
|
||||
past_length = past_key_values.get_seq_length()
|
||||
|
||||
# Some generation methods already pass only the last input ID
|
||||
if input_ids.shape[1] > past_length:
|
||||
remove_prefix_length = past_length
|
||||
else:
|
||||
# Default to old behavior: keep only final ID
|
||||
remove_prefix_length = input_ids.shape[1] - 1
|
||||
|
||||
input_ids = input_ids[:, remove_prefix_length:]
|
||||
|
||||
# input_embeds have already been used and is not required anymore
|
||||
input_embeds = None
|
||||
else:
|
||||
if input_embeds is not None and kwargs.get("use_cache"):
|
||||
seq_len = input_embeds.shape[1]
|
||||
else:
|
||||
seq_len = input_ids.shape[1]
|
||||
|
||||
# ensure that attention_mask and position_ids shapes are aligned with the weird Bark hack of reducing
|
||||
# sequence length on the first forward pass
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :seq_len]
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids[:, :seq_len]
|
||||
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
else:
|
||||
position_ids = None
|
||||
|
||||
if input_embeds is not None and kwargs.get("use_cache"):
|
||||
return {
|
||||
"input_ids": None,
|
||||
"input_embeds": input_embeds,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"position_ids": position_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"position_ids": position_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
model_inputs = super().prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=input_embeds,
|
||||
past_key_values=past_key_values,
|
||||
position_ids=position_ids,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
model_inputs["input_embeds"] = model_inputs.pop("inputs_embeds", None)
|
||||
return model_inputs
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
@@ -546,7 +508,7 @@ class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
|
||||
return_legacy_cache = True
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
|
||||
past_length = past_key_values.get_seq_length() if past_key_values is not None else past_key_values
|
||||
past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
|
||||
|
||||
@@ -25,7 +25,7 @@ from torch import nn
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
@@ -35,19 +35,12 @@ from ...utils import (
|
||||
TransformersKwargs,
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
is_torch_flex_attn_available,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
)
|
||||
from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -353,15 +346,8 @@ class ChameleonAttention(nn.Module):
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
@@ -942,7 +928,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
else:
|
||||
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
|
||||
|
||||
n_image_tokens_in_text = (special_image_mask).sum()
|
||||
n_image_tokens_in_text = special_image_mask.sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
image_embeds = self.get_image_features(pixel_values)
|
||||
@@ -966,8 +952,13 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
# embed positions
|
||||
@@ -1015,131 +1006,6 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
|
||||
@@ -270,7 +270,7 @@ class ChameleonVision2SeqModelTester(ChameleonModelTester):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
input_ids[input_ids == self.image_token_id] = self.pad_token_id
|
||||
input_ids[:, : self.image_seq_length] = self.image_token_id
|
||||
attention_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
|
||||
pixel_values = floats_tensor([self.batch_size, 3, self.image_size, self.image_size])
|
||||
|
||||
config = self.get_config()
|
||||
@@ -325,6 +325,14 @@ class ChameleonVision2SeqModelTest(ModelTesterMixin, GenerationTesterMixin, unit
|
||||
def test_model_is_small(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Chameleon applies key/query norm which doesn't work with packing")
|
||||
def test_eager_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Chameleon applies key/query norm which doesn't work with packing")
|
||||
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
def test_mismatching_num_image_tokens(self):
|
||||
"""
|
||||
Tests that VLMs through an error with explicit message saying what is wrong
|
||||
|
||||
@@ -89,7 +89,7 @@ class Emu3Text2TextModelTester:
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
@@ -234,9 +234,9 @@ class Emu3Vision2TextModelTester:
|
||||
config = self.get_config()
|
||||
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size)
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
input_ids[input_ids == self.image_token_id] = self.pad_token_id
|
||||
input_ids[:, : self.image_seq_length] = self.image_token_id
|
||||
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
|
||||
|
||||
pixel_values = floats_tensor(
|
||||
[
|
||||
|
||||
@@ -214,6 +214,14 @@ class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
def test_generate_continue_from_inputs_embeds():
|
||||
pass
|
||||
|
||||
@unittest.skip("Persimmon backbone applies key/query norm which doesn't work with packing")
|
||||
def test_eager_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Persimmon backbone applies key/query norm which doesn't work with packing")
|
||||
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
|
||||
@@ -143,6 +143,14 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 applies key/query norm which doesn't work with packing")
|
||||
def test_eager_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 applies key/query norm which doesn't work with packing")
|
||||
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
|
||||
class Gemma3Vision2TextModelTester:
|
||||
def __init__(
|
||||
|
||||
@@ -465,6 +465,14 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
):
|
||||
pass
|
||||
|
||||
@unittest.skip("KOSMOS-2 doesn't support padding")
|
||||
def test_eager_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("KOSMOS-2 doesn't support padding")
|
||||
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_left_padding_compatibility(self):
|
||||
# Overwrite because Kosmos-2 need to padd pixel values and pad image-attn-mask
|
||||
|
||||
@@ -152,9 +152,10 @@ class LlavaVisionText2TextModelTester:
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values = config_and_inputs
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
input_ids[input_ids == config.image_token_index] = self.pad_token_id
|
||||
input_ids[:, : self.num_image_tokens] = config.image_token_index
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
|
||||
inputs_dict = {
|
||||
"pixel_values": pixel_values,
|
||||
"input_ids": input_ids,
|
||||
|
||||
@@ -214,27 +214,27 @@ class MiniMaxModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
batch_size, seq_length = inputs["input_ids"].shape
|
||||
self._check_past_key_values_for_generate(batch_size, past_kv, seq_length, config)
|
||||
|
||||
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
|
||||
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
|
||||
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
|
||||
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
|
||||
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
|
||||
def test_assisted_decoding_matches_greedy_search_0_random(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
|
||||
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
|
||||
def test_assisted_decoding_matches_greedy_search_1_same(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="MiniMaxCache doesnot support `crop()` method")
|
||||
@unittest.skip(reason="MiniMaxCache does not support `crop()` method")
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@@ -242,6 +242,14 @@ class MiniMaxModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
def test_attention_outputs(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("MiniMax is special")
|
||||
def test_eager_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("MiniMax is special")
|
||||
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
|
||||
@@ -290,6 +290,14 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Paligemma position ids are 1 indexed")
|
||||
def test_eager_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Paloigemma position ids are 1 indexed")
|
||||
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
def test_attention_mask_with_token_types(self):
|
||||
"""Test that attention masking works correctly both with and without token type IDs."""
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -321,3 +321,11 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
||||
@is_flaky
|
||||
def test_generate_compile_model_forward(self):
|
||||
super().test_generate_compile_model_forward()
|
||||
|
||||
@unittest.skip("Paligemma position ids are 1 indexed")
|
||||
def test_eager_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Paligemma position ids are 1 indexed")
|
||||
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@@ -76,6 +76,14 @@ class PersimmonModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
|
||||
@unittest.skip("Persimmon applies key/query norm which doesn't work with packing")
|
||||
def test_eager_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Persimmon applies key/query norm which doesn't work with packing")
|
||||
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class PersimmonIntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -332,7 +332,7 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
def flash_attention_padding_matches_padding_free_with_position_ids(
|
||||
def attention_mask_padding_matches_padding_free_with_position_ids(
|
||||
self, attn_implementation: str, fa_kwargs: bool = False
|
||||
):
|
||||
max_new_tokens = 30
|
||||
|
||||
@@ -325,7 +325,7 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
||||
)
|
||||
self.assertIsNotNone(outputs)
|
||||
|
||||
def flash_attention_padding_matches_padding_free_with_position_ids(
|
||||
def attention_mask_padding_matches_padding_free_with_position_ids(
|
||||
self, attn_implementation: str, fa_kwargs: bool = False
|
||||
):
|
||||
max_new_tokens = 30
|
||||
|
||||
@@ -283,7 +283,7 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
generation_output.logits[0], forward_output.logits[:, -1, :], rtol=1e-4, atol=1e-4
|
||||
)
|
||||
|
||||
def flash_attention_padding_matches_padding_free_with_position_ids(
|
||||
def attention_mask_padding_matches_padding_free_with_position_ids(
|
||||
self, attn_implementation: str, fa_kwargs: bool = False
|
||||
):
|
||||
max_new_tokens = 30
|
||||
|
||||
@@ -60,6 +60,7 @@ class VoxtralModelTester:
|
||||
"use_mrope": False,
|
||||
"vocab_size": 99,
|
||||
"head_dim": 8,
|
||||
"pad_token_id": 0,
|
||||
},
|
||||
is_training=True,
|
||||
audio_config={
|
||||
|
||||
@@ -4125,9 +4125,13 @@ class ModelTesterMixin:
|
||||
|
||||
assert not loss.isnan().any()
|
||||
|
||||
def flash_attention_padding_matches_padding_free_with_position_ids(
|
||||
def attention_mask_padding_matches_padding_free_with_position_ids(
|
||||
self, attn_implementation: str, fa_kwargs: bool = False
|
||||
):
|
||||
"""
|
||||
Tests that the given attention implementation can work with packed sequences and infers the mask
|
||||
from position ids. This test requires the model to use new attention mask API which handles packing.
|
||||
"""
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
@@ -4142,17 +4146,27 @@ class ModelTesterMixin:
|
||||
if attn_implementation != "eager" and not getattr(model_class, support_flag[attn_implementation]):
|
||||
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
||||
|
||||
# can't infer if new attn mask API is supported by assume that only model with attention backend support it
|
||||
if not model_class._supports_attention_backend:
|
||||
self.skipTest(f"{model_class.__name__} does not support new attention mask API")
|
||||
|
||||
if model_class._is_stateful: # non-transformer models most probably have no packing support
|
||||
self.skipTest(f"{model_class.__name__} doesn't support packing!")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if config.is_encoder_decoder:
|
||||
self.skipTest("Model is an encoder-decoder")
|
||||
|
||||
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
|
||||
self.skipTest("Model dummy inputs should contain padding in their attention mask")
|
||||
|
||||
dummy_input = inputs_dict[model_class.main_input_name]
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
if "input_ids" not in inputs_dict or inputs_dict["input_ids"].ndim != 2:
|
||||
self.skipTest("Model dummy inputs should contain text input ids")
|
||||
|
||||
# make sure that all models have enough positions for generation
|
||||
dummy_input_ids = inputs_dict["input_ids"]
|
||||
if hasattr(config, "max_position_embeddings"):
|
||||
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
|
||||
config.max_position_embeddings = max_new_tokens + dummy_input_ids.shape[1] + 1
|
||||
|
||||
model = model_class(config)
|
||||
if "position_ids" not in inspect.signature(model.forward).parameters:
|
||||
@@ -4164,11 +4178,14 @@ class ModelTesterMixin:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
# ensure left padding, to adapt for some models
|
||||
# Drop all keys except for the minimal set. Hard to manipulate with multimodals/head_mask/etc
|
||||
inputs_dict = {k: v for k, v in inputs_dict.items() if k in ["input_ids", "attention_mask"]}
|
||||
|
||||
# Ensure left padding, to adapt for some models
|
||||
if 0 in inputs_dict["attention_mask"][:, -1]:
|
||||
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||
dummy_input_ids[~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||
|
||||
model = (
|
||||
model_class.from_pretrained(
|
||||
@@ -4183,8 +4200,7 @@ class ModelTesterMixin:
|
||||
if fa_kwargs:
|
||||
# flatten
|
||||
features = [
|
||||
{"input_ids": i[a.bool()].tolist()}
|
||||
for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
|
||||
{"input_ids": i[a.bool()].tolist()} for i, a in zip(dummy_input_ids, dummy_attention_mask)
|
||||
]
|
||||
|
||||
# add position_ids + fa_kwargs
|
||||
@@ -4194,55 +4210,48 @@ class ModelTesterMixin:
|
||||
k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()
|
||||
}
|
||||
else:
|
||||
# flatten
|
||||
padfree_inputs_dict = {
|
||||
k: v[dummy_attention_mask.bool()].unsqueeze(0)
|
||||
for k, v in inputs_dict.items()
|
||||
if not k == "attention_mask"
|
||||
}
|
||||
# add position_ids
|
||||
padfree_inputs_dict["position_ids"] = (
|
||||
# create packed position_ids
|
||||
position_ids = (
|
||||
torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])
|
||||
.long()
|
||||
.unsqueeze(0)
|
||||
.to(torch_device)
|
||||
)
|
||||
padfree_inputs_dict = {
|
||||
"input_ids": dummy_input_ids[dummy_attention_mask.bool()].unsqueeze(0),
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
|
||||
# We need to do simple forward without cache in roder to trigger packed SDPA/FLEX/EAGER path
|
||||
# We need to do simple forward without cache in order to trigger packed SDPA/flex/eager attention path
|
||||
res_padded = model(**inputs_dict, use_cache=False)
|
||||
res_padfree = model(**padfree_inputs_dict, use_cache=False)
|
||||
|
||||
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||
logits_padded = res_padded.logits[dummy_attention_mask.bool()]
|
||||
logits_padfree = res_padfree.logits[0]
|
||||
|
||||
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
|
||||
# acceptable numerical instability
|
||||
tol = torch.finfo(torch.bfloat16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
||||
# Mark slow for now as it is failing for all multimodals/non-transformer arch models and a few LLMs
|
||||
# FIXME @raushan
|
||||
@slow
|
||||
def test_eager_padding_matches_padding_free_with_position_ids(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="eager")
|
||||
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="eager")
|
||||
|
||||
@slow
|
||||
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="sdpa")
|
||||
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="sdpa")
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_2")
|
||||
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_2")
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(
|
||||
self.attention_mask_padding_matches_padding_free_with_position_ids(
|
||||
attn_implementation="flash_attention_2", fa_kwargs=True
|
||||
)
|
||||
|
||||
@@ -4251,14 +4260,14 @@ class ModelTesterMixin:
|
||||
@mark.flash_attn_3_test
|
||||
@slow
|
||||
def test_flash_attention_3_padding_matches_padding_free_with_position_ids(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_3")
|
||||
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_3")
|
||||
|
||||
@require_flash_attn_3
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_3_test
|
||||
@slow
|
||||
def test_flash_attention_3_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(
|
||||
self.attention_mask_padding_matches_padding_free_with_position_ids(
|
||||
attn_implementation="flash_attention_3", fa_kwargs=True
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user