[FA2] Add flash attention for for DistilBert (#26489)

* flash attention added for DistilBert

* fixes

* removed padding_masks

* Update modeling_distilbert.py

* Update test_modeling_distilbert.py

* style fix
This commit is contained in:
Susnato Dhar
2023-11-03 21:37:54 +05:30
committed by GitHub
parent 5964f820db
commit 1ac2463dfe
3 changed files with 348 additions and 5 deletions

View File

@@ -133,6 +133,37 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
- A blog post on how to [deploy DistilBERT with Amazon SageMaker](https://huggingface.co/blog/deploy-hugging-face-models-easily-with-amazon-sagemaker).
- A blog post on how to [Deploy BERT with Hugging Face Transformers, Amazon SageMaker and Terraform module](https://www.philschmid.de/terraform-huggingface-amazon-sagemaker).
## Combining DistilBERT and Flash Attention 2
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
```bash
pip install -U flash-attn --no-build-isolation
```
Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16`)
To load and run a model using Flash Attention 2, refer to the snippet below:
```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModel
>>> device = "cuda" # the device to load the model onto
>>> tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
>>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> text = "Replace me by any text you'd like."
>>> encoded_input = tokenizer(text, return_tensors='pt').to(device)
>>> model.to(device)
>>> output = model(**encoded_input)
```
## DistilBertConfig
[[autodoc]] DistilBertConfig