Add sdpa and fa2 the Wav2vec2 family. (#30121)
* add sdpa to wav2vec. Co-authored-by: kamilakesbi <kamil@huggingface.co> Co-authored-by: jp1924 <jp42maru@gmail.com> * add fa2 to wav2vec2 * add tests * fix attention_mask compatibility with fa2 * minor dtype fix * replace fa2 slow test * fix fa2 slow test * apply code review + add fa2 batch test * add sdpa and fa2 to hubert * sdpa and fa2 to data2vec_audio * sdpa and fa2 to Sew * sdpa to unispeech + unispeech sat * small fix * attention mask in tests Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * add_speedup_benchmark_to_doc --------- Co-authored-by: kamil@huggingface.co <kamil.akesbi@gmail.com> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -25,6 +25,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from pytest import mark
|
||||
|
||||
from transformers import Wav2Vec2Config, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
@@ -33,9 +34,11 @@ from transformers.testing_utils import (
|
||||
is_pt_flax_cross_test,
|
||||
is_pyctcdecode_available,
|
||||
is_torchaudio_available,
|
||||
require_flash_attn,
|
||||
require_pyctcdecode,
|
||||
require_soundfile,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torchaudio,
|
||||
run_test_in_subprocess,
|
||||
slow,
|
||||
@@ -1995,3 +1998,52 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
for lang in LANG_MAP.keys():
|
||||
assert run_model(lang) == TRANSCRIPTIONS[lang]
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
def test_inference_ctc_fa2(self):
|
||||
model_fa = Wav2Vec2ForCTC.from_pretrained(
|
||||
"facebook/wav2vec2-base-960h", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||
input_speech = self._load_datasamples(1)
|
||||
|
||||
input_values = processor(input_speech, return_tensors="pt").input_values.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model_fa(input_values.to(torch.bfloat16)).logits
|
||||
|
||||
predicted_ids = torch.argmax(logits, dim=-1)
|
||||
predicted_trans = processor.batch_decode(predicted_ids)
|
||||
|
||||
EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
|
||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
def test_inference_ctc_fa2_batched(self):
|
||||
model_fa = Wav2Vec2ForCTC.from_pretrained(
|
||||
"facebook/wav2vec2-base-960h", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||
|
||||
input_speech = self._load_datasamples(2)
|
||||
|
||||
inputs = processor(input_speech, return_tensors="pt", padding=True, return_attention_mask=True)
|
||||
inputs = inputs.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model_fa(inputs.input_values.to(torch.bfloat16), attention_mask=inputs.attention_mask).logits
|
||||
|
||||
predicted_ids = torch.argmax(logits, dim=-1)
|
||||
predicted_trans = processor.batch_decode(predicted_ids)
|
||||
|
||||
EXPECTED_TRANSCRIPTIONS = [
|
||||
"a man said to the universe sir i exist",
|
||||
"sweat covered brion's body trickling into the tight lowing cloth that was the only garment he wore",
|
||||
]
|
||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
Reference in New Issue
Block a user