Flash Attention 2 support for RoCm (#27611)
* support FA2 * fix typo * fix broken tests * fix more test errors * left/right * fix bug * more test * typo * fix layout flash attention falcon * do not support this case * use allclose instead of equal * fix various bugs with flash attention * bump * fix test * fix mistral * use skiptest instead of return that may be misleading * add fix causal arg flash attention * fix copies * more explicit comment * still use self.is_causal * fix causal argument * comment * fixes * update documentation * add link * wrong test * simplify FA2 RoCm requirements * update opt * make flash_attn_uses_top_left_mask attribute private and precise comment * better error handling * fix copy & mistral * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/utils/import_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * use is_flash_attn_greater_or_equal_2_10 instead of is_flash_attn_greater_or_equal_210 * fix merge * simplify * inline args --------- Co-authored-by: Felix Marty <felix@hf.co> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -56,13 +56,9 @@ The `generate()` method can be used to generate text using GPT Neo model.
|
||||
|
||||
## Combining GPT-Neo and Flash Attention 2
|
||||
|
||||
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
|
||||
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature, and make sure your hardware is compatible with Flash-Attention 2. More details are available [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2) concerning the installation.
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16``)
|
||||
Make sure as well to load your model in half-precision (e.g. `torch.float16`).
|
||||
|
||||
To load and run a model using Flash Attention 2, refer to the snippet below:
|
||||
|
||||
|
||||
@@ -38,11 +38,9 @@ FlashAttention-2 is experimental and may change considerably in future versions.
|
||||
|
||||
FlashAttention-2 supports inference with Llama, Mistral, Falcon and Bark models. You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
|
||||
|
||||
Before you begin, make sure you have FlashAttention-2 installed (see the [installation](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features) guide for more details about prerequisites):
|
||||
Before you begin, make sure you have FlashAttention-2 installed. For NVIDIA GPUs, the library is installable through pip: `pip install flash-attn --no-build-isolation`. We strongly suggest to refer to the [detailed installation instructions](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features).
|
||||
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
FlashAttention-2 is also supported on AMD GPUs, with the current support limited to **Instinct MI210 and Instinct MI250**. We strongly suggest to use the following [Dockerfile](https://github.com/huggingface/optimum-amd/tree/main/docker/transformers-pytorch-amd-gpu-flash/Dockerfile) to use FlashAttention-2 on AMD GPUs.
|
||||
|
||||
To enable FlashAttention-2, add the `use_flash_attention_2` parameter to [`~AutoModelForCausalLM.from_pretrained`]:
|
||||
|
||||
@@ -62,7 +60,7 @@ model = AutoModelForCausalLM.from_pretrained(
|
||||
|
||||
<Tip>
|
||||
|
||||
FlashAttention-2 can only be used when the model's dtype is `fp16` or `bf16`, and it only runs on Nvidia GPUs. Make sure to cast your model to the appropriate dtype and load them on a supported device before using FlashAttention-2.
|
||||
FlashAttention-2 can only be used when the model's dtype is `fp16` or `bf16`. Make sure to cast your model to the appropriate dtype and load them on a supported device before using FlashAttention-2.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user