Add sdpa and fa2 the Wav2vec2 family. (#30121)

* add sdpa to wav2vec.
Co-authored-by: kamilakesbi <kamil@huggingface.co>
Co-authored-by: jp1924 <jp42maru@gmail.com>

* add fa2 to wav2vec2

* add tests

* fix attention_mask compatibility with fa2

* minor dtype fix

* replace fa2 slow test

* fix fa2 slow test

* apply code review + add fa2 batch test

* add sdpa and fa2 to hubert

* sdpa and fa2 to data2vec_audio

* sdpa and fa2 to Sew

* sdpa to unispeech + unispeech sat

* small fix

* attention mask in tests

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* add_speedup_benchmark_to_doc

---------

Co-authored-by: kamil@huggingface.co <kamil.akesbi@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
Kamil Akesbi
2024-04-22 19:30:38 +02:00
committed by GitHub
parent 367a0dbd53
commit 569743f510
11 changed files with 2406 additions and 100 deletions

View File

@@ -25,6 +25,7 @@ import unittest
import numpy as np
from datasets import load_dataset
from pytest import mark
from transformers import Wav2Vec2Config, is_torch_available
from transformers.testing_utils import (
@@ -33,9 +34,11 @@ from transformers.testing_utils import (
is_pt_flax_cross_test,
is_pyctcdecode_available,
is_torchaudio_available,
require_flash_attn,
require_pyctcdecode,
require_soundfile,
require_torch,
require_torch_gpu,
require_torchaudio,
run_test_in_subprocess,
slow,
@@ -1995,3 +1998,52 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
for lang in LANG_MAP.keys():
assert run_model(lang) == TRANSCRIPTIONS[lang]
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
def test_inference_ctc_fa2(self):
model_fa = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-base-960h", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
)
model_fa.to(torch_device)
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
input_speech = self._load_datasamples(1)
input_values = processor(input_speech, return_tensors="pt").input_values.to(torch_device)
with torch.no_grad():
logits = model_fa(input_values.to(torch.bfloat16)).logits
predicted_ids = torch.argmax(logits, dim=-1)
predicted_trans = processor.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
def test_inference_ctc_fa2_batched(self):
model_fa = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-base-960h", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
)
model_fa.to(torch_device)
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
input_speech = self._load_datasamples(2)
inputs = processor(input_speech, return_tensors="pt", padding=True, return_attention_mask=True)
inputs = inputs.to(torch_device)
with torch.no_grad():
logits = model_fa(inputs.input_values.to(torch.bfloat16), attention_mask=inputs.attention_mask).logits
predicted_ids = torch.argmax(logits, dim=-1)
predicted_trans = processor.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight lowing cloth that was the only garment he wore",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)