Adding Flash Attention 2 Support for GPT2 (#29226)
* First commit to add flash attention 2 for GPT-2 * more improvements * Make GPT2 pass tests and fixed Decison Transformers copies * Fixed missing arg * fix copies * Added expected speedup * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Added test * Fixed attn attribute * Update docs/source/en/model_doc/gpt2.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update docs/source/en/model_doc/gpt2.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update Decision transformer attentions * More updates * Passing tests * Fix copies * Fix copies part 2 * Decision transformer updates * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Fix copies * Decision transformer not supporting flash attn * Addressed comments * Addressed comments * Addressed comments --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -60,6 +60,73 @@ This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The o
|
||||
- Enabling the *scale_attn_by_inverse_layer_idx* and *reorder_and_upcast_attn* flags will apply the training stability
|
||||
improvements from [Mistral](https://github.com/stanford-crfm/mistral/) (for PyTorch only).
|
||||
|
||||
## Usage example
|
||||
|
||||
The `generate()` method can be used to generate text using GPT2 model.
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
>>> prompt = "GPT2 is a model developed by OpenAI."
|
||||
|
||||
>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
|
||||
>>> gen_tokens = model.generate(
|
||||
... input_ids,
|
||||
... do_sample=True,
|
||||
... temperature=0.9,
|
||||
... max_length=100,
|
||||
... )
|
||||
>>> gen_text = tokenizer.batch_decode(gen_tokens)[0]
|
||||
```
|
||||
|
||||
## Using Flash Attention 2
|
||||
|
||||
Flash Attention 2 is a faster, optimized version of the attention scores computation which relies on `cuda` kernels.
|
||||
|
||||
### Installation
|
||||
|
||||
First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features). If your hardware is not compatible with Flash Attention 2, you can still benefit from attention kernel optimisations through Better Transformer support covered [above](https://huggingface.co/docs/transformers/main/en/model_doc/bark#using-better-transformer).
|
||||
|
||||
Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
### Usage
|
||||
|
||||
To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
>>> device = "cuda" # the device to load the model onto
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
>>> prompt = "def hello_world():"
|
||||
|
||||
>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
|
||||
>>> model.to(device)
|
||||
|
||||
>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
|
||||
>>> tokenizer.batch_decode(generated_ids)[0]
|
||||
```
|
||||
|
||||
|
||||
### Expected speedups
|
||||
|
||||
Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using `gpt2` checkpoint and the Flash Attention 2 version of the model using a sequence length of 512.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/EduardoPacheco/documentation-images/resolve/main/gpt2_flash_attention_2_speedup.jpg">
|
||||
</div>
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with GPT2. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
|
||||
Reference in New Issue
Block a user