Allow user-managed Pool in Wav2Vec2ProcessorWithLM.batch_decode (#18351)

* [Wav2Vec2] Allow user-managed Pool in Wav2Vec2ProcessorWithLM.batch_decode

* [Wav2Vec2] Add user-managed LM's pool tests and usage examples

* Improve styling

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* [Wav2Vec2] Fix hyperlink references

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Antonio Carlos Falcão Petri
2022-10-18 09:48:03 -03:00
committed by GitHub
parent bf0e094142
commit af150e4a1c
10 changed files with 348 additions and 63 deletions

View File

@@ -15,6 +15,7 @@
""" Testing suite for the PyTorch Wav2Vec2 model. """
import math
import multiprocessing
import os
import pickle
import tempfile
@@ -25,6 +26,7 @@ from datasets import load_dataset
from transformers import Wav2Vec2Config, is_torch_available
from transformers.testing_utils import (
CaptureLogger,
is_pt_flax_cross_test,
is_pyctcdecode_available,
is_torchaudio_available,
@@ -74,6 +76,7 @@ if is_torchaudio_available():
if is_pyctcdecode_available():
from transformers import Wav2Vec2ProcessorWithLM
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
if is_torch_fx_available():
@@ -1611,6 +1614,71 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
@require_pyctcdecode
@require_torchaudio
def test_wav2vec2_with_lm_pool(self):
ds = load_dataset("common_voice", "es", split="test", streaming=True)
sample = next(iter(ds))
resampled_audio = torchaudio.functional.resample(
torch.tensor(sample["audio"]["array"]), 48_000, 16_000
).numpy()
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
torch_device
)
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
input_values = processor(resampled_audio, return_tensors="pt").input_values
with torch.no_grad():
logits = model(input_values.to(torch_device)).logits
# test user-managed pool
with multiprocessing.get_context("fork").Pool(2) as pool:
transcription = processor.batch_decode(logits.numpy(), pool).text
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
# user-managed pool + num_processes should trigger a warning
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
2
) as pool:
transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
self.assertIn("num_process", cl.out)
self.assertIn("it will be ignored", cl.out)
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
@require_pyctcdecode
@require_torchaudio
def test_wav2vec2_with_lm_invalid_pool(self):
ds = load_dataset("common_voice", "es", split="test", streaming=True)
sample = next(iter(ds))
resampled_audio = torchaudio.functional.resample(
torch.tensor(sample["audio"]["array"]), 48_000, 16_000
).numpy()
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
torch_device
)
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
input_values = processor(resampled_audio, return_tensors="pt").input_values
with torch.no_grad():
logits = model(input_values.to(torch_device)).logits
# change default start method, which should trigger a warning if different than fork
multiprocessing.set_start_method("spawn")
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
transcription = processor.batch_decode(logits.numpy()).text
self.assertIn("Falling back to sequential decoding.", cl.out)
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
def test_inference_diarization(self):
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/wav2vec2-base-superb-sd")