Flash Attention 2 support for RoCm (#27611)

* support FA2

* fix typo

* fix broken tests

* fix more test errors

* left/right

* fix bug

* more test

* typo

* fix layout flash attention falcon

* do not support this case

* use allclose instead of equal

* fix various bugs with flash attention

* bump

* fix test

* fix mistral

* use skiptest instead of return that may be misleading

* add fix causal arg flash attention

* fix copies

* more explicit comment

* still use self.is_causal

* fix causal argument

* comment

* fixes

* update documentation

* add link

* wrong test

* simplify FA2 RoCm requirements

* update opt

* make flash_attn_uses_top_left_mask attribute private and precise comment

* better error handling

* fix copy & mistral

* Update src/transformers/modeling_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/utils/import_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* use is_flash_attn_greater_or_equal_2_10 instead of is_flash_attn_greater_or_equal_210

* fix merge

* simplify

* inline args

---------

Co-authored-by: Felix Marty <felix@hf.co>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
fxmarty
2023-12-04 13:52:17 +01:00
committed by GitHub
parent 4d4febb7aa
commit 1da1302ec8
17 changed files with 253 additions and 51 deletions

View File

@@ -56,13 +56,9 @@ The `generate()` method can be used to generate text using GPT Neo model.
## Combining GPT-Neo and Flash Attention 2 ## Combining GPT-Neo and Flash Attention 2
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature. First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature, and make sure your hardware is compatible with Flash-Attention 2. More details are available [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2) concerning the installation.
```bash Make sure as well to load your model in half-precision (e.g. `torch.float16`).
pip install -U flash-attn --no-build-isolation
```
Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16``)
To load and run a model using Flash Attention 2, refer to the snippet below: To load and run a model using Flash Attention 2, refer to the snippet below:

View File

@@ -38,11 +38,9 @@ FlashAttention-2 is experimental and may change considerably in future versions.
FlashAttention-2 supports inference with Llama, Mistral, Falcon and Bark models. You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request. FlashAttention-2 supports inference with Llama, Mistral, Falcon and Bark models. You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
Before you begin, make sure you have FlashAttention-2 installed (see the [installation](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features) guide for more details about prerequisites): Before you begin, make sure you have FlashAttention-2 installed. For NVIDIA GPUs, the library is installable through pip: `pip install flash-attn --no-build-isolation`. We strongly suggest to refer to the [detailed installation instructions](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features).
```bash FlashAttention-2 is also supported on AMD GPUs, with the current support limited to **Instinct MI210 and Instinct MI250**. We strongly suggest to use the following [Dockerfile](https://github.com/huggingface/optimum-amd/tree/main/docker/transformers-pytorch-amd-gpu-flash/Dockerfile) to use FlashAttention-2 on AMD GPUs.
pip install flash-attn --no-build-isolation
```
To enable FlashAttention-2, add the `use_flash_attention_2` parameter to [`~AutoModelForCausalLM.from_pretrained`]: To enable FlashAttention-2, add the `use_flash_attention_2` parameter to [`~AutoModelForCausalLM.from_pretrained`]:
@@ -62,7 +60,7 @@ model = AutoModelForCausalLM.from_pretrained(
<Tip> <Tip>
FlashAttention-2 can only be used when the model's dtype is `fp16` or `bf16`, and it only runs on Nvidia GPUs. Make sure to cast your model to the appropriate dtype and load them on a supported device before using FlashAttention-2. FlashAttention-2 can only be used when the model's dtype is `fp16` or `bf16`. Make sure to cast your model to the appropriate dtype and load them on a supported device before using FlashAttention-2.
</Tip> </Tip>

View File

@@ -1281,17 +1281,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
) )
if not is_flash_attn_2_available(): if not is_flash_attn_2_available():
raise ImportError(
"Flash Attention 2 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for"
" installing it. Make sure to have at least the version 2.1.0"
)
else:
flash_attention_version = version.parse(importlib.metadata.version("flash_attn")) flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
is_flash_greater_than_2 = flash_attention_version >= version.parse("2.1.0")
if not is_flash_greater_than_2: preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
raise ValueError( install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
f"You need flash_attn package version to be greater or equal than 2.1. Make sure to have that version installed - detected version {flash_attention_version}" if torch.version.cuda:
) if importlib.util.find_spec("flash_attn") is None:
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
if flash_attention_version < version.parse("2.1.0"):
raise ImportError(
f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}"
)
else:
raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
elif torch.version.hip:
if importlib.util.find_spec("flash_attn") is None:
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
if flash_attention_version < version.parse("2.0.4"):
raise ImportError(
f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Make sure to have that version installed - detected version {flash_attention_version}. {install_message}"
)
else:
raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
_is_bettertransformer = getattr(cls, "use_bettertransformer", False) _is_bettertransformer = getattr(cls, "use_bettertransformer", False)

View File

@@ -34,6 +34,7 @@ from ...utils import (
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_accelerate_available, is_accelerate_available,
is_flash_attn_2_available, is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging, logging,
) )
from ..auto import AutoModel from ..auto import AutoModel
@@ -214,6 +215,15 @@ class BarkSelfFlashAttention2(BarkSelfAttention):
flash attention and deal with padding tokens in case the input contains any of them. flash attention and deal with padding tokens in case the input contains any of them.
""" """
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def _split_heads(self, tensor, num_heads, attn_head_size): def _split_heads(self, tensor, num_heads, attn_head_size):
""" """
Splits hidden_size dim into attn_head_size and num_heads Splits hidden_size dim into attn_head_size and num_heads
@@ -301,6 +311,12 @@ class BarkSelfFlashAttention2(BarkSelfAttention):
softmax_scale (`float`, *optional*): softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
""" """
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence # Contains at least one padding token in the sequence
if attention_mask is not None: if attention_mask is not None:
batch_size = query_states.shape[0] batch_size = query_states.shape[0]
@@ -321,13 +337,13 @@ class BarkSelfFlashAttention2(BarkSelfAttention):
max_seqlen_k=max_seqlen_in_batch_k, max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout, dropout_p=dropout,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=self.is_causal, causal=causal,
) )
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else: else:
attn_output = flash_attn_func( attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
) )
return attn_output return attn_output

View File

@@ -42,6 +42,7 @@ from ...utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_flash_attn_2_available, is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
@@ -294,6 +295,15 @@ class BartFlashAttention2(BartAttention):
flash attention and deal with padding tokens in case the input contains any of them. flash attention and deal with padding tokens in case the input contains any of them.
""" """
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
@@ -418,6 +428,12 @@ class BartFlashAttention2(BartAttention):
softmax_scale (`float`, *optional*): softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
""" """
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence # Contains at least one padding token in the sequence
if attention_mask is not None: if attention_mask is not None:
batch_size = query_states.shape[0] batch_size = query_states.shape[0]
@@ -438,13 +454,13 @@ class BartFlashAttention2(BartAttention):
max_seqlen_k=max_seqlen_in_batch_k, max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout, dropout_p=dropout,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=self.is_causal, causal=causal,
) )
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else: else:
attn_output = flash_attn_func( attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
) )
return attn_output return attn_output

View File

@@ -46,6 +46,7 @@ from ...utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_flash_attn_2_available, is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
@@ -269,6 +270,15 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention):
API of flash attention and deal with padding tokens in case the input contains any of them. API of flash attention and deal with padding tokens in case the input contains any of them.
""" """
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward( def forward(
self, self,
query: torch.Tensor, query: torch.Tensor,
@@ -363,6 +373,12 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention):
softmax_scale (`float`, *optional*): softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
""" """
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence # Contains at least one padding token in the sequence
if attention_mask is not None: if attention_mask is not None:
batch_size = query_states.shape[0] batch_size = query_states.shape[0]
@@ -383,13 +399,13 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention):
max_seqlen_k=max_seqlen_in_batch_k, max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout, dropout_p=dropout,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=self.is_causal, causal=causal,
) )
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else: else:
attn_output = flash_attn_func( attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
) )
return attn_output return attn_output

View File

@@ -38,6 +38,7 @@ from ...utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_flash_attn_2_available, is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging, logging,
) )
from .configuration_falcon import FalconConfig from .configuration_falcon import FalconConfig
@@ -516,6 +517,15 @@ class FalconFlashAttention2(FalconAttention):
flash attention and deal with padding tokens in case the input contains any of them. flash attention and deal with padding tokens in case the input contains any of them.
""" """
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -631,6 +641,12 @@ class FalconFlashAttention2(FalconAttention):
softmax_scale (`float`, *optional*): softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
""" """
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence # Contains at least one padding token in the sequence
if attention_mask is not None: if attention_mask is not None:
batch_size = query_states.shape[0] batch_size = query_states.shape[0]
@@ -651,13 +667,13 @@ class FalconFlashAttention2(FalconAttention):
max_seqlen_k=max_seqlen_in_batch_k, max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout, dropout_p=dropout,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=self.is_causal, causal=causal,
) )
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else: else:
attn_output = flash_attn_func( attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
) )
return attn_output return attn_output

View File

@@ -34,6 +34,7 @@ from ...utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_flash_attn_2_available, is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging, logging,
) )
from .configuration_gpt_bigcode import GPTBigCodeConfig from .configuration_gpt_bigcode import GPTBigCodeConfig
@@ -292,6 +293,15 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
API of flash attention and deal with padding tokens in case the input contains any of them. API of flash attention and deal with padding tokens in case the input contains any of them.
""" """
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -422,6 +432,12 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
softmax_scale (`float`, *optional*): softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
""" """
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence # Contains at least one padding token in the sequence
if attention_mask is not None: if attention_mask is not None:
batch_size = query_states.shape[0] batch_size = query_states.shape[0]
@@ -442,13 +458,13 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
max_seqlen_k=max_seqlen_in_batch_k, max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout, dropout_p=dropout,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=self.is_causal, causal=causal,
) )
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else: else:
attn_output = flash_attn_func( attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
) )
return attn_output return attn_output

View File

@@ -42,6 +42,7 @@ from ...utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_flash_attn_2_available, is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torch_fx_available, is_torch_fx_available,
logging, logging,
) )
@@ -299,6 +300,15 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
flash attention and deal with padding tokens in case the input contains any of them. flash attention and deal with padding tokens in case the input contains any of them.
""" """
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward( def forward(
self, self,
hidden_states, hidden_states,
@@ -400,6 +410,12 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
softmax_scale (`float`, *optional*): softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
""" """
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence # Contains at least one padding token in the sequence
if attention_mask is not None: if attention_mask is not None:
batch_size = query_states.shape[0] batch_size = query_states.shape[0]
@@ -420,13 +436,13 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
max_seqlen_k=max_seqlen_in_batch_k, max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout, dropout_p=dropout,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=self.is_causal, causal=causal,
) )
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else: else:
attn_output = flash_attn_func( attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
) )
return attn_output return attn_output

View File

@@ -37,6 +37,7 @@ from ...utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_flash_attn_2_available, is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
@@ -442,6 +443,14 @@ class LlamaFlashAttention2(LlamaAttention):
flash attention and deal with padding tokens in case the input contains any of them. flash attention and deal with padding tokens in case the input contains any of them.
""" """
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -491,6 +500,8 @@ class LlamaFlashAttention2(LlamaAttention):
past_key_value = (key_states, value_states) if use_cache else None past_key_value = (key_states, value_states) if use_cache else None
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2) query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2) key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2) value_states = value_states.transpose(1, 2)
@@ -555,6 +566,12 @@ class LlamaFlashAttention2(LlamaAttention):
softmax_scale (`float`, *optional*): softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
""" """
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence # Contains at least one padding token in the sequence
if attention_mask is not None: if attention_mask is not None:
batch_size = query_states.shape[0] batch_size = query_states.shape[0]
@@ -575,13 +592,13 @@ class LlamaFlashAttention2(LlamaAttention):
max_seqlen_k=max_seqlen_in_batch_k, max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout, dropout_p=dropout,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=self.is_causal, causal=causal,
) )
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else: else:
attn_output = flash_attn_func( attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
) )
return attn_output return attn_output

View File

@@ -41,6 +41,7 @@ from ...utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_flash_attn_2_available, is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
@@ -289,6 +290,15 @@ class MBartFlashAttention2(MBartAttention):
flash attention and deal with padding tokens in case the input contains any of them. flash attention and deal with padding tokens in case the input contains any of them.
""" """
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
@@ -413,6 +423,12 @@ class MBartFlashAttention2(MBartAttention):
softmax_scale (`float`, *optional*): softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
""" """
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence # Contains at least one padding token in the sequence
if attention_mask is not None: if attention_mask is not None:
batch_size = query_states.shape[0] batch_size = query_states.shape[0]
@@ -433,13 +449,13 @@ class MBartFlashAttention2(MBartAttention):
max_seqlen_k=max_seqlen_in_batch_k, max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout, dropout_p=dropout,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=self.is_causal, causal=causal,
) )
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else: else:
attn_output = flash_attn_func( attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
) )
return attn_output return attn_output

View File

@@ -37,6 +37,7 @@ from ...utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_flash_attn_2_available, is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
@@ -312,6 +313,15 @@ class MistralFlashAttention2(MistralAttention):
flash attention and deal with padding tokens in case the input contains any of them. flash attention and deal with padding tokens in case the input contains any of them.
""" """
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -470,6 +480,12 @@ class MistralFlashAttention2(MistralAttention):
use_sliding_windows (`bool`, *optional*): use_sliding_windows (`bool`, *optional*):
Whether to activate sliding window attention. Whether to activate sliding window attention.
""" """
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence # Contains at least one padding token in the sequence
if attention_mask is not None: if attention_mask is not None:
batch_size = query_states.shape[0] batch_size = query_states.shape[0]
@@ -491,7 +507,7 @@ class MistralFlashAttention2(MistralAttention):
max_seqlen_k=max_seqlen_in_batch_k, max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout, dropout_p=dropout,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=self.is_causal, causal=causal,
) )
else: else:
attn_output_unpad = flash_attn_varlen_func( attn_output_unpad = flash_attn_varlen_func(
@@ -504,7 +520,7 @@ class MistralFlashAttention2(MistralAttention):
max_seqlen_k=max_seqlen_in_batch_k, max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout, dropout_p=dropout,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=self.is_causal, causal=causal,
window_size=(self.config.sliding_window, self.config.sliding_window), window_size=(self.config.sliding_window, self.config.sliding_window),
) )
@@ -517,7 +533,7 @@ class MistralFlashAttention2(MistralAttention):
value_states, value_states,
dropout, dropout,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=self.is_causal, causal=causal,
) )
else: else:
attn_output = flash_attn_func( attn_output = flash_attn_func(
@@ -526,7 +542,7 @@ class MistralFlashAttention2(MistralAttention):
value_states, value_states,
dropout, dropout,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=self.is_causal, causal=causal,
window_size=(self.config.sliding_window, self.config.sliding_window), window_size=(self.config.sliding_window, self.config.sliding_window),
) )

View File

@@ -35,6 +35,7 @@ from ...utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_flash_attn_2_available, is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
@@ -288,6 +289,15 @@ class OptFlashAttention2(OPTAttention):
attention and deal with padding tokens in case the input contains any of them. attention and deal with padding tokens in case the input contains any of them.
""" """
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -404,6 +414,12 @@ class OptFlashAttention2(OPTAttention):
softmax_scale (`float`, *optional*): softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
""" """
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence # Contains at least one padding token in the sequence
if attention_mask is not None: if attention_mask is not None:
batch_size = query_states.shape[0] batch_size = query_states.shape[0]
@@ -424,13 +440,13 @@ class OptFlashAttention2(OPTAttention):
max_seqlen_k=max_seqlen_in_batch_k, max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout, dropout_p=dropout,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=self.is_causal, causal=causal,
) )
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else: else:
attn_output = flash_attn_func( attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
) )
return attn_output return attn_output

View File

@@ -41,6 +41,7 @@ from ...utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_flash_attn_2_available, is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
@@ -478,6 +479,15 @@ class WhisperFlashAttention2(WhisperAttention):
flash attention and deal with padding tokens in case the input contains any of them. flash attention and deal with padding tokens in case the input contains any of them.
""" """
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
@@ -602,6 +612,12 @@ class WhisperFlashAttention2(WhisperAttention):
softmax_scale (`float`, *optional*): softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
""" """
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence # Contains at least one padding token in the sequence
if attention_mask is not None: if attention_mask is not None:
batch_size = query_states.shape[0] batch_size = query_states.shape[0]
@@ -622,13 +638,13 @@ class WhisperFlashAttention2(WhisperAttention):
max_seqlen_k=max_seqlen_in_batch_k, max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout, dropout_p=dropout,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=self.is_causal, causal=causal,
) )
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else: else:
attn_output = flash_attn_func( attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
) )
return attn_output return attn_output

View File

@@ -118,6 +118,7 @@ from .import_utils import (
is_faiss_available, is_faiss_available,
is_flash_attn_2_available, is_flash_attn_2_available,
is_flash_attn_available, is_flash_attn_available,
is_flash_attn_greater_or_equal_2_10,
is_flax_available, is_flax_available,
is_fsdp_available, is_fsdp_available,
is_ftfy_available, is_ftfy_available,

View File

@@ -71,9 +71,6 @@ TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True) _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex") _apex_available = _is_package_available("apex")
_bitsandbytes_available = _is_package_available("bitsandbytes") _bitsandbytes_available = _is_package_available("bitsandbytes")
_flash_attn_2_available = _is_package_available("flash_attn") and version.parse(
importlib.metadata.version("flash_attn")
) >= version.parse("2.1.0")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed. # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None _bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs") _coloredlogs_available = _is_package_available("coloredlogs")
@@ -608,10 +605,29 @@ def is_flash_attn_2_available():
if not is_torch_available(): if not is_torch_available():
return False return False
if not _is_package_available("flash_attn"):
return False
# Let's add an extra check to see if cuda is available # Let's add an extra check to see if cuda is available
import torch import torch
return _flash_attn_2_available and torch.cuda.is_available() if not torch.cuda.is_available():
return False
if torch.version.cuda:
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
elif torch.version.hip:
# TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.0.4")
else:
return False
def is_flash_attn_greater_or_equal_2_10():
if not _is_package_available("flash_attn"):
return False
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
def is_flash_attn_available(): def is_flash_attn_available():

View File

@@ -3087,7 +3087,7 @@ class ModelTesterMixin:
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
) )
self.assertTrue(torch.equal(out, out_fa)) self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn @require_flash_attn
@require_torch_gpu @require_torch_gpu
@@ -3130,7 +3130,7 @@ class ModelTesterMixin:
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
) )
self.assertTrue(torch.equal(out, out_fa)) self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn @require_flash_attn
@require_torch_gpu @require_torch_gpu