From 81233c069c166af033794134bd8888783ac49ebe Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Fri, 26 Jul 2024 14:45:55 +0500 Subject: [PATCH] Flash-Attn: fix generation when no attention mask or no padding (#32241) * fix * fix prev test (half of failures) * [run-slow] llama, gemma2 * [run-slow] llama, gemma2 --- .../modeling_flash_attention_utils.py | 6 ++++-- tests/test_modeling_common.py | 15 ++++++++++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 4f0ff8817b..7bb3ee03c0 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -264,9 +264,11 @@ def _flash_attention_forward( ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - # if position_ids is provided and check not all examples (row) contain only 1 sequence, + # if position_ids is provided and check not all examples (row) contain only 1 sequence, and is in pre-fill/training stage # then use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach - elif position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all(): + elif ( + position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all() and query_length != 1 + ): batch_size = query_states.size(0) query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( query_states, key_states, value_states, position_ids diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index abe5ddea2c..4e7f9bdf14 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4270,6 +4270,18 @@ class ModelTesterMixin: use_cache=True, ) + # Generate with one batch only to test generation when attention mask will be None + # when real inputs are used, because there is no 
padding. See issue #32237 for more + dummy_input = dummy_input[:1, ...] + dummy_attention_mask = torch.ones_like(dummy_attention_mask[:1, ...]) + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + @require_flash_attn @require_torch_gpu @require_bitsandbytes @@ -4342,6 +4354,8 @@ class ModelTesterMixin: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict: + self.skipTest("Model dummy inputs should contain padding in their attention mask") dummy_input = inputs_dict[model_class.main_input_name] if dummy_input.dtype in [torch.float32, torch.bfloat16]: @@ -4356,7 +4370,6 @@ class ModelTesterMixin: with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) - assert 0 in inputs_dict["attention_mask"], "assert padding in testing inputs" # ensure left padding, to adapt for some models if 0 in inputs_dict["attention_mask"][:, -1]: inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)