Support for Flash Attention 3 (#38972)
* Support `flash_attn_3` Implements fwd and tests for Flash Attention 3 https://github.com/Dao-AILab/flash-attention/commits/main/hopper - Includes checks for dropout>0 and ALiBi in `modeling_utils.PreTrainedModel._check_and_enable_flash_attn_3` (Dropout will likely be supported soon, so this will need to be updated and `modeling_flash_attention_utils._flash_attention_forward` at the `if _IS_FLASH_ATTN_3_AVAILABLE: ...` An example Llama implementation is included in `modeling_llama.py` but other models would still need to be updated Based on https://github.com/huggingface/transformers/pull/36190 which has model implementations and examples which could be merged * Add tests for Flash Attention 2 and 3 parity * ci fix * FA2 compatibiity - `_prepare_flash_attention_from_position_ids` ->`prepare_fa2_from_position_ids` - Remove bettertransformer check in Flash Attention 3 - Merge tests - Add licensing * ci fix * Test naming consistency * ci fix * Deprecation warning for `prepare_fa2_from_position_ids` * ci fix
This commit is contained in:
@@ -77,6 +77,7 @@ from transformers.utils import (
|
||||
)
|
||||
from transformers.utils.import_utils import (
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_3_available,
|
||||
is_flax_available,
|
||||
is_tf_available,
|
||||
is_torch_npu_available,
|
||||
@@ -676,6 +677,9 @@ class ModelUtilsTest(TestCasePlus):
|
||||
if is_flash_attn_available():
|
||||
attn_implementation_available.append("flash_attention_2")
|
||||
|
||||
if is_flash_attn_3_available():
|
||||
attn_implementation_available.append("flash_attention_3")
|
||||
|
||||
for requested_attn_implementation in attn_implementation_available:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
TINY_MISTRAL, attn_implementation=requested_attn_implementation
|
||||
@@ -700,6 +704,9 @@ class ModelUtilsTest(TestCasePlus):
|
||||
if is_flash_attn_available():
|
||||
attn_implementation_available.append("flash_attention_2")
|
||||
|
||||
if is_flash_attn_3_available():
|
||||
attn_implementation_available.append("flash_attention_3")
|
||||
|
||||
for requested_attn_implementation in attn_implementation_available:
|
||||
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation)
|
||||
# Ensure the config was set correctly
|
||||
|
||||
Reference in New Issue
Block a user