[Mistral] Add Flash Attention-2 support for mistral (#26464)

* add FA-2 support for mistral

* fixup

* add sliding windows

* fixing few nits

* v1 slicing cache - logits do not match

* add comment

* fix bugs

* more mem efficient

* add warning once

* add warning once

* oops

* fixup

* more comments

* copy

* add safety checker

* fixup

* Update src/transformers/models/mistral/modeling_mistral.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* copied from

* up

* raise when padding side is right

* fixup

* add doc + few minor changes

* fixup

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2023-10-03 13:44:46 +02:00
committed by GitHub
parent 1a2e966cfe
commit ae9a344cce
5 changed files with 435 additions and 5 deletions

View File

@@ -32,6 +32,7 @@ Make sure to follow the installation guide on the repository mentioned above to
We natively support Flash Attention 2 for the following models:
- Llama
- Mistral
- Falcon
You can request to add Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported for `BetterTransformer` API below.*