Flash-Attn: fix generation when no attention mask or no pading (#32241)
* fix * fix prev test (half of failures) * [run-slow] llama, gemma2 * [run-slow] llama, gemma2
This commit is contained in:
committed by
GitHub
parent
27c7f971c0
commit
81233c069c
@@ -264,9 +264,11 @@ def _flash_attention_forward(
|
|||||||
)
|
)
|
||||||
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||||
|
|
||||||
# if position_ids is provided and check not all examples (row) contain only 1 sequence,
|
# if position_ids is provided and check not all examples (row) contain only 1 sequence, and is in pre-fill/training stage
|
||||||
# then use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
|
# then use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
|
||||||
elif position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all():
|
elif (
|
||||||
|
position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all() and query_length != 1
|
||||||
|
):
|
||||||
batch_size = query_states.size(0)
|
batch_size = query_states.size(0)
|
||||||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
|
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
|
||||||
query_states, key_states, value_states, position_ids
|
query_states, key_states, value_states, position_ids
|
||||||
|
|||||||
@@ -4270,6 +4270,18 @@ class ModelTesterMixin:
|
|||||||
use_cache=True,
|
use_cache=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Generate with one batch only to test generation when attention mask will be None
|
||||||
|
# when real inputs are used, because there is no padding. See issue #32237 for more
|
||||||
|
dummy_input = dummy_input[:1, ...]
|
||||||
|
dummy_attention_mask = torch.ones_like(dummy_attention_mask[:1, ...])
|
||||||
|
_ = model.generate(
|
||||||
|
dummy_input,
|
||||||
|
attention_mask=dummy_attention_mask,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
do_sample=False,
|
||||||
|
use_cache=True,
|
||||||
|
)
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
@@ -4342,6 +4354,8 @@ class ModelTesterMixin:
|
|||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
|
||||||
|
self.skipTest("Model dummy inputs should contain padding in their attention mask")
|
||||||
|
|
||||||
dummy_input = inputs_dict[model_class.main_input_name]
|
dummy_input = inputs_dict[model_class.main_input_name]
|
||||||
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
|
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
|
||||||
@@ -4356,7 +4370,6 @@ class ModelTesterMixin:
|
|||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
model.save_pretrained(tmpdirname)
|
model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
assert 0 in inputs_dict["attention_mask"], "assert padding in testing inputs"
|
|
||||||
# ensure left padding, to adapt for some models
|
# ensure left padding, to adapt for some models
|
||||||
if 0 in inputs_dict["attention_mask"][:, -1]:
|
if 0 in inputs_dict["attention_mask"][:, -1]:
|
||||||
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||||
|
|||||||
Reference in New Issue
Block a user