[core ] Integrate Flash attention 2 in most used models (#25598)
* v1 * oops * working v1 * fixup * add some TODOs * fixup * padding support + try with module replacement * nit * alternative design * oops * add `use_cache` support for llama * v1 falcon * nit * a bit of refactor * nit * nits nits * add v1 padding support falcon (even though it seemed to work before) * nit * falcon works * fixup * v1 tests * nit * fix generation llama flash * update tests * fix tests + nits * fix copies * fix nit * test- padding mask * stype * add more mem efficient support * Update src/transformers/modeling_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * fixup * nit * fixup * remove it from config when saving * fixup * revert docstring * add more checks * use values * oops * new version * fixup * add same trick for falcon * nit * add another test * change tests * fix issues with GC and also falcon * fixup * oops * Update src/transformers/models/falcon/modeling_falcon.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add init_rope * updates * fix copies * fixup * fixup * more clarification * fixup * right padding tests * add docs * add FA in docker image * more clarifications * add some figures * add todo * rectify comment * Change to FA2 * Update docs/source/en/perf_infer_gpu_one.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * split in two lines * change test name * add more tests * some clean up * remove `rearrange` deps * add more docs * revert changes on dockerfile * Revert "revert changes on dockerfile" This reverts commit 8d72a66b4b9b771abc3f15a9b9506b4246d62d8e. * revert changes on dockerfile * Apply suggestions from code review Co-authored-by: Lysandre Debut <hi@lysand.re> * address some comments * docs * use inheritance * Update src/transformers/testing_utils.py Co-authored-by: Lysandre Debut <hi@lysand.re> * fixup * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/modeling_utils.py * final comments * clean up * style * add cast + warning for PEFT models * fixup --------- Co-authored-by: Felix Marty <9808326+fxmarty@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
@@ -17,6 +17,154 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
In addition to this guide, relevant information can be found as well in [the guide for training on a single GPU](perf_train_gpu_one) and [the guide for inference on CPUs](perf_infer_cpu).
|
||||
|
||||
## Flash Attention 2
|
||||
|
||||
<Tip>
|
||||
|
||||
Note that this feature is experimental and might considerably change in future versions. For instance, the Flash Attention 2 API might migrate to `BetterTransformer` API in the near future.
|
||||
|
||||
</Tip>
|
||||
|
||||
Flash Attention 2 can considerably speed up transformer-based models' training and inference speed. Flash Attention 2 has been introduced in the [official Flash Attention repository](https://github.com/Dao-AILab/flash-attention) by Tri Dao et al. The scientific paper on Flash Attention can be found [here](https://arxiv.org/abs/2205.14135).
|
||||
|
||||
Make sure to follow the installation guide on the repository mentioned above to properly install Flash Attention 2. Once that package is installed, you can benefit from this feature.
|
||||
|
||||
We natively support Flash Attention 2 for the following models:
|
||||
|
||||
- Llama
|
||||
- Falcon
|
||||
|
||||
You can request to add Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported for `BetterTransformer` API below.*
|
||||
|
||||
<Tip>
|
||||
|
||||
Flash Attention 2 can only be used when the models' dtype is `fp16` or `bf16` and runs only on NVIDIA-GPU devices. Make sure to cast your model to the appropriate dtype and load them on a supported device before using that feature.
|
||||
|
||||
</Tip>
|
||||
|
||||
### Quick usage
|
||||
|
||||
To enable Flash Attention 2 in your model, add `use_flash_attention_2` in the `from_pretrained` arguments:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
model_id = "tiiuae/falcon-7b"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=torch.bfloat16,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
```
|
||||
|
||||
And use it for generation or fine-tuning.
|
||||
|
||||
### Expected speedups
|
||||
|
||||
You can benefit from considerable speedups for fine-tuning and inference, especially for long sequences. However, since Flash Attention does not support computing attention scores with padding tokens under the hood, we must manually pad / unpad the attention scores for batched inference when the sequence contains padding tokens. This leads to a significant slowdown for batched generations with padding tokens.
|
||||
|
||||
To overcome this, one should use Flash Attention without padding tokens in the sequence for training (e.g., by packing a dataset, i.e., concatenating sequences until reaching the maximum sequence length. An example is provided [here](https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py#L516).
|
||||
|
||||
Below is the expected speedup you can get for a simple forward pass on [tiiuae/falcon-7b](https://hf.co/tiiuae/falcon-7b) with a sequence length of 4096 and various batch sizes without padding tokens:
|
||||
|
||||
Below is the expected speedup you can get for a simple forward pass on [tiiuae/falcon-7b](https://hf.co/tiiuae/falcon-7b) with a sequence length of 4096 and various batch sizes, without padding tokens:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/falcon-7b-inference-large-seqlen.png">
|
||||
</div>
|
||||
|
||||
Below is the expected speedup you can get for a simple forward pass on [`meta-llama/Llama-7b-hf`](https://hf.co/meta-llama/Llama-7b-hf) with a sequence length of 4096 and various batch sizes, without padding tokens:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-7b-inference-large-seqlen.png">
|
||||
</div>
|
||||
|
||||
For sequences with padding tokens (training with padding tokens or generating with padding tokens), we need to unpad / pad the input sequences to compute correctly the attention scores. For relatively small sequence length, on pure forward pass, this creates an overhead leading to a small speedup (below 30% of the input has been filled with padding tokens).
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-small-seqlen-padding.png">
|
||||
</div>
|
||||
|
||||
But for large sequence length you can benefit from interesting speedup for pure inference (also training)
|
||||
|
||||
Note that Flash Attention makes the attention computation more memory efficient, meaning you can train with much larger sequence lengths without facing CUDA OOM issues. It can lead up to memory reduction up to 20 for large sequence length. Check out [the official flash attention repository](https://github.com/Dao-AILab/flash-attention) for more details.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-large-seqlen-padding.png">
|
||||
</div>
|
||||
|
||||
|
||||
### Advanced usage
|
||||
|
||||
You can combine this feature with many exisiting feature for model optimization. Check out few examples below:
|
||||
|
||||
### Combining Flash Attention 2 and 8-bit models
|
||||
|
||||
You can combine this feature together with 8-bit quantization:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
model_id = "tiiuae/falcon-7b"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
load_in_8bit=True,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
```
|
||||
|
||||
### Combining Flash Attention 2 and 4-bit models
|
||||
|
||||
You can combine this feature together with 4-bit quantization:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
model_id = "tiiuae/falcon-7b"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
load_in_4bit=True,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
```
|
||||
|
||||
### Combining Flash Attention 2 and PEFT
|
||||
|
||||
You can combine this feature together with PEFT for training adapters using Flash Attention 2 under the hood:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
|
||||
from peft import LoraConfig
|
||||
|
||||
model_id = "tiiuae/falcon-7b"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
load_in_4bit=True,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
model.add_adapter(lora_config)
|
||||
|
||||
... # train your model
|
||||
```
|
||||
|
||||
## BetterTransformer
|
||||
|
||||
[BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview) converts 🤗 Transformers models to use the PyTorch-native fastpath execution, which calls optimized kernels like Flash Attention under the hood.
|
||||
|
||||
Reference in New Issue
Block a user