[attention] fix test for packed padfree masking (#39582)
* fix most tests
* skip a few more tests
* address comments
* fix chameleon tests
* forgot to uncomment
* qwen has its own tests with images, rename it as well
This commit is contained in:
committed by
GitHub
parent
565c035a2e
commit
c392d47c9b
@@ -4125,9 +4125,13 @@ class ModelTesterMixin:
|
||||
|
||||
assert not loss.isnan().any()
|
||||
|
||||
def flash_attention_padding_matches_padding_free_with_position_ids(
|
||||
def attention_mask_padding_matches_padding_free_with_position_ids(
|
||||
self, attn_implementation: str, fa_kwargs: bool = False
|
||||
):
|
||||
"""
|
||||
Tests that the given attention implementation can work with packed sequences and infers the mask
|
||||
from position ids. This test requires the model to use the new attention mask API, which handles packing.
|
||||
"""
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
@@ -4142,17 +4146,27 @@ class ModelTesterMixin:
|
||||
if attn_implementation != "eager" and not getattr(model_class, support_flag[attn_implementation]):
|
||||
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
||||
|
||||
# We can't infer whether the new attn mask API is supported, so assume that only models with attention backend support use it
|
||||
if not model_class._supports_attention_backend:
|
||||
self.skipTest(f"{model_class.__name__} does not support new attention mask API")
|
||||
|
||||
if model_class._is_stateful: # non-transformer models most probably have no packing support
|
||||
self.skipTest(f"{model_class.__name__} doesn't support packing!")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if config.is_encoder_decoder:
|
||||
self.skipTest("Model is an encoder-decoder")
|
||||
|
||||
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
|
||||
self.skipTest("Model dummy inputs should contain padding in their attention mask")
|
||||
|
||||
dummy_input = inputs_dict[model_class.main_input_name]
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
if "input_ids" not in inputs_dict or inputs_dict["input_ids"].ndim != 2:
|
||||
self.skipTest("Model dummy inputs should contain text input ids")
|
||||
|
||||
# make sure that all models have enough positions for generation
|
||||
dummy_input_ids = inputs_dict["input_ids"]
|
||||
if hasattr(config, "max_position_embeddings"):
|
||||
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
|
||||
config.max_position_embeddings = max_new_tokens + dummy_input_ids.shape[1] + 1
|
||||
|
||||
model = model_class(config)
|
||||
if "position_ids" not in inspect.signature(model.forward).parameters:
|
||||
@@ -4164,11 +4178,14 @@ class ModelTesterMixin:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
# ensure left padding, to adapt for some models
|
||||
# Drop all keys except for the minimal set. Other inputs (multimodal features, head_mask, etc.) are hard to manipulate
|
||||
inputs_dict = {k: v for k, v in inputs_dict.items() if k in ["input_ids", "attention_mask"]}
|
||||
|
||||
# Ensure left padding, to adapt for some models
|
||||
if 0 in inputs_dict["attention_mask"][:, -1]:
|
||||
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||
dummy_input_ids[~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||
|
||||
model = (
|
||||
model_class.from_pretrained(
|
||||
@@ -4183,8 +4200,7 @@ class ModelTesterMixin:
|
||||
if fa_kwargs:
|
||||
# flatten
|
||||
features = [
|
||||
{"input_ids": i[a.bool()].tolist()}
|
||||
for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
|
||||
{"input_ids": i[a.bool()].tolist()} for i, a in zip(dummy_input_ids, dummy_attention_mask)
|
||||
]
|
||||
|
||||
# add position_ids + fa_kwargs
|
||||
@@ -4194,55 +4210,48 @@ class ModelTesterMixin:
|
||||
k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()
|
||||
}
|
||||
else:
|
||||
# flatten
|
||||
padfree_inputs_dict = {
|
||||
k: v[dummy_attention_mask.bool()].unsqueeze(0)
|
||||
for k, v in inputs_dict.items()
|
||||
if not k == "attention_mask"
|
||||
}
|
||||
# add position_ids
|
||||
padfree_inputs_dict["position_ids"] = (
|
||||
# create packed position_ids
|
||||
position_ids = (
|
||||
torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])
|
||||
.long()
|
||||
.unsqueeze(0)
|
||||
.to(torch_device)
|
||||
)
|
||||
padfree_inputs_dict = {
|
||||
"input_ids": dummy_input_ids[dummy_attention_mask.bool()].unsqueeze(0),
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
|
||||
# We need to do simple forward without cache in roder to trigger packed SDPA/FLEX/EAGER path
|
||||
# We need to do simple forward without cache in order to trigger packed SDPA/flex/eager attention path
|
||||
res_padded = model(**inputs_dict, use_cache=False)
|
||||
res_padfree = model(**padfree_inputs_dict, use_cache=False)
|
||||
|
||||
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||
logits_padded = res_padded.logits[dummy_attention_mask.bool()]
|
||||
logits_padfree = res_padfree.logits[0]
|
||||
|
||||
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
|
||||
# acceptable numerical instability
|
||||
tol = torch.finfo(torch.bfloat16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
||||
# Mark slow for now as it is failing for all multimodals/non-transformer arch models and a few LLMs
|
||||
# FIXME @raushan
|
||||
@slow
|
||||
def test_eager_padding_matches_padding_free_with_position_ids(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="eager")
|
||||
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="eager")
|
||||
|
||||
@slow
|
||||
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="sdpa")
|
||||
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="sdpa")
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_2")
|
||||
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_2")
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(
|
||||
self.attention_mask_padding_matches_padding_free_with_position_ids(
|
||||
attn_implementation="flash_attention_2", fa_kwargs=True
|
||||
)
|
||||
|
||||
@@ -4251,14 +4260,14 @@ class ModelTesterMixin:
|
||||
@mark.flash_attn_3_test
|
||||
@slow
|
||||
def test_flash_attention_3_padding_matches_padding_free_with_position_ids(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_3")
|
||||
self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_3")
|
||||
|
||||
@require_flash_attn_3
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_3_test
|
||||
@slow
|
||||
def test_flash_attention_3_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(
|
||||
self.attention_mask_padding_matches_padding_free_with_position_ids(
|
||||
attn_implementation="flash_attention_3", fa_kwargs=True
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user