[FA2] Add flash attention for for DistilBert (#26489)
* flash attention added for DistilBert * fixes * removed padding_masks * Update modeling_distilbert.py * Update test_modeling_distilbert.py * style fix
This commit is contained in:
@@ -133,6 +133,37 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
||||
- A blog post on how to [deploy DistilBERT with Amazon SageMaker](https://huggingface.co/blog/deploy-hugging-face-models-easily-with-amazon-sagemaker).
|
||||
- A blog post on how to [Deploy BERT with Hugging Face Transformers, Amazon SageMaker and Terraform module](https://www.philschmid.de/terraform-huggingface-amazon-sagemaker).
|
||||
|
||||
|
||||
## Combining DistilBERT and Flash Attention 2
|
||||
|
||||
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16`)
|
||||
|
||||
To load and run a model using Flash Attention 2, refer to the snippet below:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import AutoTokenizer, AutoModel
|
||||
|
||||
>>> device = "cuda" # the device to load the model onto
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
|
||||
>>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, use_flash_attention_2=True)
|
||||
|
||||
>>> text = "Replace me by any text you'd like."
|
||||
|
||||
>>> encoded_input = tokenizer(text, return_tensors='pt').to(device)
|
||||
>>> model.to(device)
|
||||
|
||||
>>> output = model(**encoded_input)
|
||||
```
|
||||
|
||||
|
||||
## DistilBertConfig
|
||||
|
||||
[[autodoc]] DistilBertConfig
|
||||
|
||||
Reference in New Issue
Block a user