Add Flash Attention 2 to M2M100 model (#30256)
* Added flash attention 2. * Fixes. * Fix inheritance. * Fixed init. * Remove stuff. * Added documentation. * Add FA2 to M2M100 documentation. * Add test. * Fixed documentation. * Update src/transformers/models/m2m_100/modeling_m2m_100.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update docs/source/en/model_doc/nllb.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Fixed variable name. --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
ec92f983af
commit
b65df514d1
@@ -121,3 +121,45 @@ Hindi to French and Chinese to English using the *facebook/m2m100_418M* checkpoi
|
||||
|
||||
[[autodoc]] M2M100ForConditionalGeneration
|
||||
- forward
|
||||
|
||||
## Using Flash Attention 2
|
||||
|
||||
Flash Attention 2 is a faster, optimized version of the attention scores computation which relies on `cuda` kernels.
|
||||
|
||||
### Installation
|
||||
|
||||
First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features).
|
||||
|
||||
Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
### Usage
|
||||
|
||||
To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). You can use either `torch.float16` or `torch.bfloat16` precision.
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
|
||||
|
||||
>>> model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda").eval()
|
||||
>>> tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
|
||||
|
||||
>>> # translate Hindi to French
|
||||
>>> hi_text = "जीवन एक चॉकलेट बॉक्स की तरह है।"
|
||||
>>> tokenizer.src_lang = "hi"
|
||||
>>> encoded_hi = tokenizer(hi_text, return_tensors="pt").to("cuda")
|
||||
>>> generated_tokens = model.generate(**encoded_hi, forced_bos_token_id=tokenizer.get_lang_id("fr"))
|
||||
>>> tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
|
||||
"La vie est comme une boîte de chocolat."
|
||||
```
|
||||
|
||||
### Expected speedups
|
||||
|
||||
Below is an expected speedup diagram that compares pure inference time between the native implementation and the Flash Attention 2.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/visheratin/documentation-images/resolve/main/nllb-speedup.webp">
|
||||
</div>
|
||||
|
||||
@@ -145,3 +145,46 @@ UN-Chef sagt, es gibt keine militärische Lösung in Syrien
|
||||
## NllbTokenizerFast
|
||||
|
||||
[[autodoc]] NllbTokenizerFast
|
||||
|
||||
## Using Flash Attention 2
|
||||
|
||||
Flash Attention 2 is a faster, optimized version of the attention scores computation which relies on `cuda` kernels.
|
||||
|
||||
### Installation
|
||||
|
||||
First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features).
|
||||
|
||||
Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
### Usage
|
||||
|
||||
To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). You can use either `torch.float16` or `torch.bfloat16` precision.
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
>>> model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda").eval()
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
|
||||
|
||||
>>> article = "Şeful ONU spune că nu există o soluţie militară în Siria"
|
||||
>>> inputs = tokenizer(article, return_tensors="pt").to("cuda")
|
||||
|
||||
>>> translated_tokens = model.generate(
|
||||
... **inputs, forced_bos_token_id=tokenizer.lang_code_to_id["deu_Latn"], max_length=30
|
||||
... )
|
||||
>>> tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
|
||||
"UN-Chef sagt, es gibt keine militärische Lösung in Syrien"
|
||||
```
|
||||
|
||||
### Expected speedups
|
||||
|
||||
Below is an expected speedup diagram that compares pure inference time between the native implementation and the Flash Attention 2.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/visheratin/documentation-images/resolve/main/nllb-speedup.webp">
|
||||
</div>
|
||||
Reference in New Issue
Block a user