Add sdpa and fa2 the Wav2vec2 family. (#30121)

* add sdpa to wav2vec.
Co-authored-by: kamilakesbi <kamil@huggingface.co>
Co-authored-by: jp1924 <jp42maru@gmail.com>

* add fa2 to wav2vec2

* add tests

* fix attention_mask compatibility with fa2

* minor dtype fix

* replace fa2 slow test

* fix fa2 slow test

* apply code review + add fa2 batch test

* add sdpa and fa2 to hubert

* sdpa and fa2 to data2vec_audio

* sdpa and fa2 to Sew

* sdpa to unispeech + unispeech sat

* small fix

* attention mask in tests

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* add_speedup_benchmark_to_doc

---------

Co-authored-by: kamil@huggingface.co <kamil.akesbi@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
Kamil Akesbi
2024-04-22 19:30:38 +02:00
committed by GitHub
parent 367a0dbd53
commit 569743f510
11 changed files with 2406 additions and 100 deletions

View File

@@ -44,6 +44,42 @@ This model was contributed by [patrickvonplaten](https://huggingface.co/patrickv
- Hubert model was fine-tuned using connectionist temporal classification (CTC) so the model output has to be decoded
using [`Wav2Vec2CTCTokenizer`].
## Using Flash Attention 2
Flash Attention 2 is an faster, optimized version of the model.
### Installation
First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features). If your hardware is not compatible with Flash Attention 2, you can still benefit from attention kernel optimisations through Better Transformer support covered [above](https://huggingface.co/docs/transformers/main/en/model_doc/bark#using-better-transformer).
Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:
```bash
pip install -U flash-attn --no-build-isolation
```
### Usage
Below is an expected speedup diagram comparing the pure inference time between the native implementation in transformers of `facebook/hubert-large-ls960-ft`, the flash-attention-2 and the sdpa (scale-dot-product-attention) version. We show the average speedup obtained on the `librispeech_asr` `clean` validation split:
```python
>>> from transformers import Wav2Vec2Model
model = Wav2Vec2Model.from_pretrained("facebook/hubert-large-ls960-ft", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
...
```
### Expected speedups
Below is an expected speedup diagram comparing the pure inference time between the native implementation in transformers of the `facebook/hubert-large-ls960-ft` model and the flash-attention-2 and sdpa (scale-dot-product-attention) versions. . We show the average speedup obtained on the `librispeech_asr` `clean` validation split:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/kamilakesbi/transformers_image_doc/resolve/main/data/Hubert_speedup.png">
</div>
## Resources
- [Audio classification task guide](../tasks/audio_classification)