[FA-2] Add Flash Attention to Phi (#27661)

* add FA and modify doc file

* test_flash_attn_2_generate_padding_right test overwritten

* comment

* modify persimmon modeling file

* added speedup graph

* more changes
This commit is contained in:
Susnato Dhar
2023-12-07 12:27:48 +05:30
committed by GitHub
parent 06f561687c
commit f84d85ba67
4 changed files with 345 additions and 10 deletions

View File

@@ -76,7 +76,7 @@ The original code for Phi-1 and Phi-1.5 can be found [here](https://huggingface.
```python
>>> from transformers import PhiForCausalLM, AutoTokenizer
>>> # define the model and tokenzier.
>>> # define the model and tokenizer.
>>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev")
>>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")
@@ -94,6 +94,46 @@ The original code for Phi-1 and Phi-1.5 can be found [here](https://huggingface.
```
## Combining Phi and Flash Attention 2
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
```bash
pip install -U flash-attn --no-build-isolation
```
Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16``)
To load and run a model using Flash Attention 2, refer to the snippet below:
```python
>>> import torch
>>> from transformers import PhiForCausalLM, AutoTokenizer
>>> # define the model and tokenizer and push the model and tokens to the GPU.
>>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev", torch_dtype=torch.float16, use_flash_attention_2=True).to("cuda")
>>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")
>>> # feel free to change the prompt to your liking.
>>> prompt = "If I were an AI that had just achieved"
>>> # apply the tokenizer.
>>> tokens = tokenizer(prompt, return_tensors="pt").to("cuda")
>>> # use the model to generate new tokens.
>>> generated_output = model.generate(**tokens, use_cache=True, max_new_tokens=10)
>>> tokenizer.batch_decode(generated_output)[0]
'If I were an AI that had just achieved a breakthrough in machine learning, I would be thrilled'
```
### Expected speedups
Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using `susnato/phi-1_dev` checkpoint and the Flash Attention 2 version of the model using a sequence length of 2048.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/phi_1_speedup_plot.jpg">
</div>
## PhiConfig
[[autodoc]] PhiConfig