|
|
|
|
@@ -56,9 +56,24 @@ FlashAttention-2 is currently supported for the following architectures:
|
|
|
|
|
|
|
|
|
|
You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
|
|
|
|
|
|
|
|
|
|
Before you begin, make sure you have FlashAttention-2 installed. For NVIDIA GPUs, the library is installable through pip: `pip install flash-attn --no-build-isolation`. We strongly suggest to refer to the [detailed installation instructions](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features).
|
|
|
|
|
Before you begin, make sure you have FlashAttention-2 installed.
|
|
|
|
|
|
|
|
|
|
FlashAttention-2 is also supported on AMD GPUs, with the current support limited to **Instinct MI210 and Instinct MI250**. We strongly suggest to use the following [Dockerfile](https://github.com/huggingface/optimum-amd/tree/main/docker/transformers-pytorch-amd-gpu-flash/Dockerfile) to use FlashAttention-2 on AMD GPUs.
|
|
|
|
|
<hfoptions id="install">
|
|
|
|
|
<hfoption id="NVIDIA">
|
|
|
|
|
|
|
|
|
|
```bash
|
|
|
|
|
pip install flash-attn --no-build-isolation
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
We strongly suggest referring to the detailed [installation instructions](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features) to learn more about supported hardware and data types!
|
|
|
|
|
|
|
|
|
|
</hfoption>
|
|
|
|
|
<hfoption id="AMD">
|
|
|
|
|
|
|
|
|
|
FlashAttention-2 is also supported on AMD GPUs and current support is limited to **Instinct MI210** and **Instinct MI250**. We strongly suggest using this [Dockerfile](https://github.com/huggingface/optimum-amd/tree/main/docker/transformers-pytorch-amd-gpu-flash/Dockerfile) to use FlashAttention-2 on AMD GPUs.
|
|
|
|
|
|
|
|
|
|
</hfoption>
|
|
|
|
|
</hfoptions>
|
|
|
|
|
|
|
|
|
|
To enable FlashAttention-2, pass the argument `attn_implementation="flash_attention_2"` to [`~AutoModelForCausalLM.from_pretrained`]:
|
|
|
|
|
|
|
|
|
|
@@ -80,7 +95,9 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
|
|
|
|
|
|
FlashAttention-2 can only be used when the model's dtype is `fp16` or `bf16`. Make sure to cast your model to the appropriate dtype and load them on a supported device before using FlashAttention-2.
|
|
|
|
|
|
|
|
|
|
Note that `use_flash_attention_2=True` can also be used to enable Flash Attention 2, but is deprecated in favor of `attn_implementation="flash_attention_2"`.
|
|
|
|
|
<br>
|
|
|
|
|
|
|
|
|
|
You can also set `use_flash_attention_2=True` to enable FlashAttention-2 but it is deprecated in favor of `attn_implementation="flash_attention_2"`.
|
|
|
|
|
|
|
|
|
|
</Tip>
|
|
|
|
|
|
|
|
|
|
@@ -144,11 +161,11 @@ FlashAttention is more memory efficient, meaning you can train on much larger se
|
|
|
|
|
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-large-seqlen-padding.png">
|
|
|
|
|
</div>
|
|
|
|
|
|
|
|
|
|
## FlashAttention and memory-efficient attention through PyTorch's scaled_dot_product_attention
|
|
|
|
|
## PyTorch scaled dot product attention
|
|
|
|
|
|
|
|
|
|
PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers, and is used by default for `torch>=2.1.1` when an implementation is available.
|
|
|
|
|
PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available.
|
|
|
|
|
|
|
|
|
|
For now, Transformers supports inference and training through SDPA for the following architectures:
|
|
|
|
|
For now, Transformers supports SDPA inference and training for the following architectures:
|
|
|
|
|
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
|
|
|
|
|
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
|
|
|
|
|
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
|
|
|
|
|
@@ -156,9 +173,13 @@ For now, Transformers supports inference and training through SDPA for the follo
|
|
|
|
|
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
|
|
|
|
|
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
|
|
|
|
|
|
|
|
|
|
Note that FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type before using it.
|
|
|
|
|
<Tip>
|
|
|
|
|
|
|
|
|
|
By default, `torch.nn.functional.scaled_dot_product_attention` selects the most performant kernel available, but to check whether a backend is available in a given setting (hardware, problem size), you can use [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:
|
|
|
|
|
FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type first.
|
|
|
|
|
|
|
|
|
|
</Tip>
|
|
|
|
|
|
|
|
|
|
By default, SDPA selects the most performant kernel available but you can check whether a backend is available in a given setting (hardware, problem size) with [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:
|
|
|
|
|
|
|
|
|
|
```diff
|
|
|
|
|
import torch
|
|
|
|
|
@@ -178,7 +199,7 @@ inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
|
|
|
|
|
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
If you see a bug with the traceback below, try using nightly version of PyTorch which may have broader coverage for FlashAttention:
|
|
|
|
|
If you see a bug with the traceback below, try using the nightly version of PyTorch which may have broader coverage for FlashAttention:
|
|
|
|
|
|
|
|
|
|
```bash
|
|
|
|
|
RuntimeError: No available kernel. Aborting execution.
|
|
|
|
|
@@ -191,11 +212,10 @@ pip3 install -U --pre torch torchvision torchaudio --index-url https://download.
|
|
|
|
|
|
|
|
|
|
<Tip warning={true}>
|
|
|
|
|
|
|
|
|
|
Part of BetterTransformer features are being upstreamed in Transformers, with native `torch.nn.scaled_dot_product_attention` default support. BetterTransformer still has a wider coverage than the Transformers SDPA integration, but you can expect more and more architectures to support natively SDPA in Transformers.
|
|
|
|
|
Some BetterTransformer features are being upstreamed to Transformers with default support for native `torch.nn.scaled_dot_product_attention`. BetterTransformer still has a wider coverage than the Transformers SDPA integration, but you can expect more and more architectures to natively support SDPA in Transformers.
|
|
|
|
|
|
|
|
|
|
</Tip>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<Tip>
|
|
|
|
|
|
|
|
|
|
Check out our benchmarks with BetterTransformer and scaled dot product attention in the [Out of the box acceleration and memory savings of 🤗 decoder models with PyTorch 2.0](https://pytorch.org/blog/out-of-the-box-acceleration/) and learn more about the fastpath execution in the [BetterTransformer](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2) blog post.
|
|
|
|
|
|