Support for Flash Attention 3 (#38972)
* Support `flash_attn_3` Implements fwd and tests for Flash Attention 3 https://github.com/Dao-AILab/flash-attention/commits/main/hopper - Includes checks for dropout>0 and ALiBi in `modeling_utils.PreTrainedModel._check_and_enable_flash_attn_3` (Dropout will likely be supported soon, so this will need to be updated and `modeling_flash_attention_utils._flash_attention_forward` at the `if _IS_FLASH_ATTN_3_AVAILABLE: ...` An example Llama implementation is included in `modeling_llama.py` but other models would still need to be updated Based on https://github.com/huggingface/transformers/pull/36190 which has model implementations and examples which could be merged * Add tests for Flash Attention 2 and 3 parity * ci fix * FA2 compatibiity - `_prepare_flash_attention_from_position_ids` ->`prepare_fa2_from_position_ids` - Remove bettertransformer check in Flash Attention 3 - Merge tests - Add licensing * ci fix * Test naming consistency * ci fix * Deprecation warning for `prepare_fa2_from_position_ids` * ci fix
This commit is contained in:
@@ -34,6 +34,7 @@ from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_accelerate,
|
||||
require_flash_attn,
|
||||
require_flash_attn_3,
|
||||
require_optimum_quanto,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
@@ -2292,6 +2293,7 @@ class GenerationTesterMixin:
|
||||
support_flag = {
|
||||
"sdpa": "_supports_sdpa",
|
||||
"flash_attention_2": "_supports_flash_attn_2",
|
||||
"flash_attention_3": "_supports_flash_attn_3",
|
||||
}
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
@@ -2369,6 +2371,14 @@ class GenerationTesterMixin:
|
||||
"""Tests that generate has equivalent outputs with FA2 and eager attention implementations."""
|
||||
self._test_attention_implementation("flash_attention_2")
|
||||
|
||||
@pytest.mark.flash_attn_3_test
|
||||
@require_flash_attn_3
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_eager_matches_fa3_generate(self):
|
||||
"""Tests that generate has equivalent outputs with FA3 and eager attention implementations."""
|
||||
self._test_attention_implementation("flash_attention_3")
|
||||
|
||||
def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
|
||||
input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
|
||||
internal_batch_size = (
|
||||
|
||||
Reference in New Issue
Block a user