[qwen2 vl] fix packing with all attentions (#39447)
* fix qwen2 vl packing in FA2 * why? delete! * qwen2-5-vl seems to work now * update * fix tests * start by adapting FA2 tests * add similar tests for sdpa/eager * address comments * why is this even in conditional model and not base model?
This commit is contained in:
committed by
GitHub
parent
e42681b48b
commit
344012b3a6
@@ -332,6 +332,92 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
def flash_attention_padding_matches_padding_free_with_position_ids(
|
||||
self, attn_implementation: str, fa_kwargs: bool = False
|
||||
):
|
||||
max_new_tokens = 30
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
dummy_input = inputs_dict[model_class.main_input_name]
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
|
||||
# make sure that all models have enough positions for generation
|
||||
if hasattr(config, "max_position_embeddings"):
|
||||
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
if 0 in inputs_dict["attention_mask"][:, -1]:
|
||||
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||
|
||||
model = (
|
||||
model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation=attn_implementation,
|
||||
)
|
||||
.to(torch_device)
|
||||
.eval()
|
||||
)
|
||||
|
||||
# flatten
|
||||
padfree_inputs_dict = {
|
||||
"input_features": inputs_dict["input_features"],
|
||||
"feature_attention_mask": inputs_dict["feature_attention_mask"],
|
||||
"pixel_values": inputs_dict["pixel_values"],
|
||||
"image_grid_thw": inputs_dict["image_grid_thw"],
|
||||
"input_ids": inputs_dict["input_ids"][dummy_attention_mask.bool()].unsqueeze(0),
|
||||
}
|
||||
|
||||
# add position_ids
|
||||
vision_position_ids, deltas = model.get_rope_index(
|
||||
input_ids=inputs_dict["input_ids"],
|
||||
image_grid_thw=inputs_dict["image_grid_thw"],
|
||||
attention_mask=inputs_dict["attention_mask"],
|
||||
audio_seqlens=torch.sum(inputs_dict["feature_attention_mask"], dim=1),
|
||||
) # [3, bs, padded-seq-len]
|
||||
vision_padfree_positions = vision_position_ids[:, dummy_attention_mask.bool()].view(
|
||||
3, -1
|
||||
) # [3, bs*padfree-len]
|
||||
text_padfree_positions = torch.cat(
|
||||
[torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()]
|
||||
) # [1, bs*padfree-len]
|
||||
text_padfree_positions = text_padfree_positions.long().unsqueeze(0).to(torch_device)
|
||||
padfree_inputs_dict["position_ids"] = torch.cat([text_padfree_positions, vision_padfree_positions])[
|
||||
:, None, :
|
||||
]
|
||||
|
||||
if fa_kwargs:
|
||||
cu_seq_lens = [0] + dummy_attention_mask.sum(1).tolist()
|
||||
cu_seq_lens = torch.tensor(cu_seq_lens, device=torch_device)
|
||||
max_length = cu_seq_lens.diff().max().item()
|
||||
padfree_inputs_dict.update(
|
||||
{
|
||||
"cu_seq_lens_q": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
|
||||
"cu_seq_lens_k": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
|
||||
"max_length_q": max_length,
|
||||
"max_length_k": max_length,
|
||||
}
|
||||
)
|
||||
|
||||
res_padded = model(**inputs_dict, use_cache=False)
|
||||
res_padfree = model(**padfree_inputs_dict, use_cache=False)
|
||||
|
||||
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||
logits_padfree = res_padfree.logits[0]
|
||||
|
||||
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
|
||||
# acceptable numerical instability
|
||||
tol = torch.finfo(torch.bfloat16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
||||
@unittest.skip("Cannot do contrastive generation, has custom `generate()`")
|
||||
def test_contrastive_generate(self):
|
||||
pass
|
||||
|
||||
@@ -325,6 +325,89 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
||||
)
|
||||
self.assertIsNotNone(outputs)
|
||||
|
||||
def flash_attention_padding_matches_padding_free_with_position_ids(
|
||||
self, attn_implementation: str, fa_kwargs: bool = False
|
||||
):
|
||||
max_new_tokens = 30
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
dummy_input = inputs_dict[model_class.main_input_name]
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
|
||||
# make sure that all models have enough positions for generation
|
||||
if hasattr(config, "max_position_embeddings"):
|
||||
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
if 0 in inputs_dict["attention_mask"][:, -1]:
|
||||
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||
|
||||
model = (
|
||||
model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation=attn_implementation,
|
||||
)
|
||||
.to(torch_device)
|
||||
.eval()
|
||||
)
|
||||
|
||||
# flatten
|
||||
padfree_inputs_dict = {
|
||||
"pixel_values": inputs_dict["pixel_values"],
|
||||
"image_grid_thw": inputs_dict["image_grid_thw"],
|
||||
"input_ids": inputs_dict["input_ids"][dummy_attention_mask.bool()].unsqueeze(0),
|
||||
}
|
||||
|
||||
# add position_ids
|
||||
vision_position_ids, deltas = model.model.get_rope_index(
|
||||
input_ids=inputs_dict["input_ids"],
|
||||
image_grid_thw=inputs_dict["image_grid_thw"],
|
||||
attention_mask=inputs_dict["attention_mask"],
|
||||
) # [3, bs, padded-seq-len]
|
||||
vision_padfree_positions = vision_position_ids[:, dummy_attention_mask.bool()].view(
|
||||
3, -1
|
||||
) # [3, bs*padfree-len]
|
||||
text_padfree_positions = torch.cat(
|
||||
[torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()]
|
||||
) # [1, bs*padfree-len]
|
||||
text_padfree_positions = text_padfree_positions.long().unsqueeze(0).to(torch_device)
|
||||
padfree_inputs_dict["position_ids"] = torch.cat([text_padfree_positions, vision_padfree_positions])[
|
||||
:, None, :
|
||||
]
|
||||
|
||||
if fa_kwargs:
|
||||
cu_seq_lens = [0] + dummy_attention_mask.sum(1).tolist()
|
||||
cu_seq_lens = torch.tensor(cu_seq_lens, device=torch_device)
|
||||
max_length = cu_seq_lens.diff().max().item()
|
||||
padfree_inputs_dict.update(
|
||||
{
|
||||
"cu_seq_lens_q": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
|
||||
"cu_seq_lens_k": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
|
||||
"max_length_q": max_length,
|
||||
"max_length_k": max_length,
|
||||
}
|
||||
)
|
||||
|
||||
res_padded = model(**inputs_dict, use_cache=False)
|
||||
res_padfree = model(**padfree_inputs_dict, use_cache=False)
|
||||
|
||||
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||
logits_padfree = res_padfree.logits[0]
|
||||
|
||||
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
|
||||
# acceptable numerical instability
|
||||
tol = torch.finfo(torch.bfloat16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
||||
@unittest.skip(reason="Feedforward chunking is not yet supported")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
@@ -168,6 +169,7 @@ class Qwen2VLVisionText2TextModelTester:
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||
|
||||
input_ids[:, -1] = self.pad_token_id
|
||||
attention_mask[:, -1] = 0
|
||||
input_ids[input_ids == self.video_token_id] = self.pad_token_id
|
||||
input_ids[input_ids == self.image_token_id] = self.pad_token_id
|
||||
input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
|
||||
@@ -281,6 +283,90 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
generation_output.logits[0], forward_output.logits[:, -1, :], rtol=1e-4, atol=1e-4
|
||||
)
|
||||
|
||||
def flash_attention_padding_matches_padding_free_with_position_ids(
|
||||
self, attn_implementation: str, fa_kwargs: bool = False
|
||||
):
|
||||
max_new_tokens = 30
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
dummy_input = inputs_dict[model_class.main_input_name]
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
|
||||
# make sure that all models have enough positions for generation
|
||||
if hasattr(config, "max_position_embeddings"):
|
||||
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
if 0 in inputs_dict["attention_mask"][:, -1]:
|
||||
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||
|
||||
model = (
|
||||
model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation=attn_implementation,
|
||||
)
|
||||
.to(torch_device)
|
||||
.eval()
|
||||
)
|
||||
|
||||
# flatten
|
||||
padfree_inputs_dict = {
|
||||
"pixel_values": inputs_dict["pixel_values"],
|
||||
"image_grid_thw": inputs_dict["image_grid_thw"],
|
||||
"input_ids": inputs_dict["input_ids"][dummy_attention_mask.bool()].unsqueeze(0),
|
||||
}
|
||||
|
||||
# add position_ids
|
||||
vision_position_ids, deltas = model.model.get_rope_index(
|
||||
input_ids=inputs_dict["input_ids"],
|
||||
image_grid_thw=inputs_dict["image_grid_thw"],
|
||||
attention_mask=inputs_dict["attention_mask"],
|
||||
) # [3, bs, padded-seq-len]
|
||||
vision_padfree_positions = vision_position_ids[:, dummy_attention_mask.bool()].view(
|
||||
3, -1
|
||||
) # [3, bs*padfree-len]
|
||||
text_padfree_positions = torch.cat(
|
||||
[torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()]
|
||||
) # [1, bs*padfree-len]
|
||||
text_padfree_positions = text_padfree_positions.long().unsqueeze(0).to(torch_device)
|
||||
padfree_inputs_dict["position_ids"] = torch.cat([text_padfree_positions, vision_padfree_positions])[
|
||||
:, None, :
|
||||
]
|
||||
|
||||
if fa_kwargs:
|
||||
cu_seq_lens = [0] + dummy_attention_mask.sum(1).tolist()
|
||||
cu_seq_lens = torch.tensor(cu_seq_lens, device=torch_device)
|
||||
max_length = cu_seq_lens.diff().max().item()
|
||||
padfree_inputs_dict.update(
|
||||
{
|
||||
"cu_seq_lens_q": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
|
||||
"cu_seq_lens_k": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
|
||||
"max_length_q": max_length,
|
||||
"max_length_k": max_length,
|
||||
}
|
||||
)
|
||||
|
||||
# We need to do simple forward without cache in roder to trigger packed SDPA/FLEX/EAGER path
|
||||
res_padded = model(**inputs_dict, use_cache=False)
|
||||
res_padfree = model(**padfree_inputs_dict, use_cache=False)
|
||||
|
||||
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||
logits_padfree = res_padfree.logits[0]
|
||||
|
||||
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
|
||||
# acceptable numerical instability
|
||||
tol = torch.finfo(torch.bfloat16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
||||
@unittest.skip(reason="Feedforward chunking is not yet supported")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@@ -4129,13 +4129,14 @@ class ModelTesterMixin:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
max_new_tokens = 30
|
||||
support_flag = {
|
||||
"sdpa": "_supports_sdpa",
|
||||
"flash_attention_2": "_supports_flash_attn",
|
||||
"flash_attention_3": "_supports_flash_attn",
|
||||
}
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not (
|
||||
model_class._supports_flash_attn_2
|
||||
if attn_implementation == "flash_attention_2"
|
||||
else model_class._supports_flash_attn_3
|
||||
):
|
||||
if not getattr(model_class, support_flag[attn_implementation]):
|
||||
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -4204,8 +4205,9 @@ class ModelTesterMixin:
|
||||
.to(torch_device)
|
||||
)
|
||||
|
||||
res_padded = model(**inputs_dict)
|
||||
res_padfree = model(**padfree_inputs_dict)
|
||||
# We need to do simple forward without cache in roder to trigger packed SDPA/FLEX/EAGER path
|
||||
res_padded = model(**inputs_dict, use_cache=False)
|
||||
res_padfree = model(**padfree_inputs_dict, use_cache=False)
|
||||
|
||||
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||
logits_padfree = res_padfree.logits[0]
|
||||
@@ -4215,6 +4217,16 @@ class ModelTesterMixin:
|
||||
tol = torch.finfo(torch.bfloat16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
||||
# Mark slow for now as it is failing for all multimodals/non-transformer arch models and a few LLMs
|
||||
# FIXME @raushan
|
||||
@slow
|
||||
def test_eager_padding_matches_padding_free_with_position_ids(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="eager")
|
||||
|
||||
@slow
|
||||
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="sdpa")
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
|
||||
Reference in New Issue
Block a user