Allow user-managed Pool in Wav2Vec2ProcessorWithLM.batch_decode (#18351)
* [Wav2Vec2] Allow user-managed Pool in Wav2Vec2ProcessorWithLM.batch_decode * [Wav2Vec2] Add user-managed LM's pool tests and usage examples * Improve styling Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * [Wav2Vec2] Fix hyperlink references Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
bf0e094142
commit
af150e4a1c
@@ -14,6 +14,7 @@
|
||||
|
||||
import inspect
|
||||
import math
|
||||
import multiprocessing
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -21,6 +22,7 @@ from datasets import load_dataset
|
||||
|
||||
from transformers import Wav2Vec2Config, is_flax_available
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
is_flaky,
|
||||
is_librosa_available,
|
||||
is_pt_flax_cross_test,
|
||||
@@ -53,6 +55,7 @@ if is_flax_available():
|
||||
|
||||
if is_pyctcdecode_available():
|
||||
from transformers import Wav2Vec2ProcessorWithLM
|
||||
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
|
||||
|
||||
|
||||
if is_librosa_available():
|
||||
@@ -554,3 +557,58 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
transcription = processor.batch_decode(np.array(logits)).text
|
||||
|
||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||
|
||||
@require_pyctcdecode
|
||||
@require_librosa
|
||||
def test_wav2vec2_with_lm_pool(self):
|
||||
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
||||
sample = next(iter(ds))
|
||||
|
||||
resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
|
||||
|
||||
model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
|
||||
input_values = processor(resampled_audio, return_tensors="np").input_values
|
||||
|
||||
logits = model(input_values).logits
|
||||
|
||||
# test user-managed pool
|
||||
with multiprocessing.get_context("fork").Pool(2) as pool:
|
||||
transcription = processor.batch_decode(logits.numpy(), pool).text
|
||||
|
||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||
|
||||
# user-managed pool + num_processes should trigger a warning
|
||||
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
|
||||
2
|
||||
) as pool:
|
||||
transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
|
||||
|
||||
self.assertIn("num_process", cl.out)
|
||||
self.assertIn("it will be ignored", cl.out)
|
||||
|
||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||
|
||||
@require_pyctcdecode
|
||||
@require_librosa
|
||||
def test_wav2vec2_with_lm_invalid_pool(self):
|
||||
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
||||
sample = next(iter(ds))
|
||||
|
||||
resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
|
||||
|
||||
model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
|
||||
input_values = processor(resampled_audio, return_tensors="np").input_values
|
||||
|
||||
logits = model(input_values).logits
|
||||
|
||||
# change default start method, which should trigger a warning if different than fork
|
||||
multiprocessing.set_start_method("spawn")
|
||||
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
|
||||
transcription = processor.batch_decode(logits.numpy()).text
|
||||
|
||||
self.assertIn("Falling back to sequential decoding.", cl.out)
|
||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||
|
||||
@@ -18,6 +18,7 @@ import copy
|
||||
import glob
|
||||
import inspect
|
||||
import math
|
||||
import multiprocessing
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -26,7 +27,7 @@ from datasets import load_dataset
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import Wav2Vec2Config, is_tf_available
|
||||
from transformers.testing_utils import is_flaky, require_librosa, require_pyctcdecode, require_tf, slow
|
||||
from transformers.testing_utils import CaptureLogger, is_flaky, require_librosa, require_pyctcdecode, require_tf, slow
|
||||
from transformers.utils import is_librosa_available, is_pyctcdecode_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -42,6 +43,7 @@ if is_tf_available():
|
||||
|
||||
if is_pyctcdecode_available():
|
||||
from transformers import Wav2Vec2ProcessorWithLM
|
||||
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
|
||||
|
||||
|
||||
if is_librosa_available():
|
||||
@@ -590,3 +592,56 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
transcription = processor.batch_decode(logits.numpy()).text
|
||||
|
||||
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
||||
|
||||
@require_pyctcdecode
|
||||
@require_librosa
|
||||
def test_wav2vec2_with_lm_pool(self):
|
||||
downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
|
||||
file_path = glob.glob(downloaded_folder + "/*")[0]
|
||||
sample = librosa.load(file_path, sr=16_000)[0]
|
||||
|
||||
model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
|
||||
input_values = processor(sample, return_tensors="tf").input_values
|
||||
|
||||
logits = model(input_values).logits
|
||||
|
||||
# test user-managed pool
|
||||
with multiprocessing.get_context("fork").Pool(2) as pool:
|
||||
transcription = processor.batch_decode(logits.numpy(), pool).text
|
||||
|
||||
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
||||
|
||||
# user-managed pool + num_processes should trigger a warning
|
||||
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
|
||||
2
|
||||
) as pool:
|
||||
transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
|
||||
|
||||
self.assertIn("num_process", cl.out)
|
||||
self.assertIn("it will be ignored", cl.out)
|
||||
|
||||
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
||||
|
||||
@require_pyctcdecode
|
||||
@require_librosa
|
||||
def test_wav2vec2_with_lm_invalid_pool(self):
|
||||
downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
|
||||
file_path = glob.glob(downloaded_folder + "/*")[0]
|
||||
sample = librosa.load(file_path, sr=16_000)[0]
|
||||
|
||||
model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
|
||||
input_values = processor(sample, return_tensors="tf").input_values
|
||||
|
||||
logits = model(input_values).logits
|
||||
|
||||
# change default start method, which should trigger a warning if different than fork
|
||||
multiprocessing.set_start_method("spawn")
|
||||
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
|
||||
transcription = processor.batch_decode(logits.numpy()).text
|
||||
|
||||
self.assertIn("Falling back to sequential decoding.", cl.out)
|
||||
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
""" Testing suite for the PyTorch Wav2Vec2 model. """
|
||||
|
||||
import math
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
@@ -25,6 +26,7 @@ from datasets import load_dataset
|
||||
|
||||
from transformers import Wav2Vec2Config, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
is_pt_flax_cross_test,
|
||||
is_pyctcdecode_available,
|
||||
is_torchaudio_available,
|
||||
@@ -74,6 +76,7 @@ if is_torchaudio_available():
|
||||
|
||||
if is_pyctcdecode_available():
|
||||
from transformers import Wav2Vec2ProcessorWithLM
|
||||
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
|
||||
|
||||
|
||||
if is_torch_fx_available():
|
||||
@@ -1611,6 +1614,71 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||
|
||||
@require_pyctcdecode
|
||||
@require_torchaudio
|
||||
def test_wav2vec2_with_lm_pool(self):
|
||||
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
||||
sample = next(iter(ds))
|
||||
|
||||
resampled_audio = torchaudio.functional.resample(
|
||||
torch.tensor(sample["audio"]["array"]), 48_000, 16_000
|
||||
).numpy()
|
||||
|
||||
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
|
||||
torch_device
|
||||
)
|
||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
|
||||
input_values = processor(resampled_audio, return_tensors="pt").input_values
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(input_values.to(torch_device)).logits
|
||||
|
||||
# test user-managed pool
|
||||
with multiprocessing.get_context("fork").Pool(2) as pool:
|
||||
transcription = processor.batch_decode(logits.numpy(), pool).text
|
||||
|
||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||
|
||||
# user-managed pool + num_processes should trigger a warning
|
||||
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
|
||||
2
|
||||
) as pool:
|
||||
transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
|
||||
|
||||
self.assertIn("num_process", cl.out)
|
||||
self.assertIn("it will be ignored", cl.out)
|
||||
|
||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||
|
||||
@require_pyctcdecode
|
||||
@require_torchaudio
|
||||
def test_wav2vec2_with_lm_invalid_pool(self):
|
||||
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
||||
sample = next(iter(ds))
|
||||
|
||||
resampled_audio = torchaudio.functional.resample(
|
||||
torch.tensor(sample["audio"]["array"]), 48_000, 16_000
|
||||
).numpy()
|
||||
|
||||
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
|
||||
torch_device
|
||||
)
|
||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
|
||||
input_values = processor(resampled_audio, return_tensors="pt").input_values
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(input_values.to(torch_device)).logits
|
||||
|
||||
# change default start method, which should trigger a warning if different than fork
|
||||
multiprocessing.set_start_method("spawn")
|
||||
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
|
||||
transcription = processor.batch_decode(logits.numpy()).text
|
||||
|
||||
self.assertIn("Falling back to sequential decoding.", cl.out)
|
||||
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
||||
|
||||
def test_inference_diarization(self):
|
||||
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/wav2vec2-base-superb-sd")
|
||||
|
||||
@@ -25,6 +25,7 @@ import numpy as np
|
||||
from datasets import load_dataset
|
||||
from packaging import version
|
||||
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoProcessor
|
||||
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||
@@ -194,7 +195,8 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
||||
self.assertEqual(decoded_decoder[-2], decoded_processor.logit_score)
|
||||
self.assertEqual(decoded_decoder[-1], decoded_processor.lm_score)
|
||||
|
||||
def test_decoder_batch(self):
|
||||
@parameterized.expand([[None], ["fork"], ["spawn"]])
|
||||
def test_decoder_batch(self, pool_context):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
decoder = self.get_decoder()
|
||||
@@ -203,17 +205,25 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
||||
|
||||
logits = self._get_dummy_logits()
|
||||
|
||||
decoded_processor = processor.batch_decode(logits)
|
||||
# note: pool should be instantiated *after* Wav2Vec2ProcessorWithLM.
|
||||
# otherwise, the LM won't be available to the pool's sub-processes.
|
||||
# manual logic used to allow parameterized test for both pool=None and pool=Pool(...)
|
||||
if pool_context is None:
|
||||
decoded_processor = processor.batch_decode(logits)
|
||||
else:
|
||||
with get_context(pool_context).Pool() as pool:
|
||||
decoded_processor = processor.batch_decode(logits, pool)
|
||||
|
||||
logits_list = [array for array in logits]
|
||||
pool = get_context("fork").Pool()
|
||||
decoded_beams = decoder.decode_beams_batch(pool, logits_list)
|
||||
|
||||
with get_context("fork").Pool() as p:
|
||||
decoded_beams = decoder.decode_beams_batch(p, logits_list)
|
||||
|
||||
texts_decoder, logit_scores_decoder, lm_scores_decoder = [], [], []
|
||||
for beams in decoded_beams:
|
||||
texts_decoder.append(beams[0][0])
|
||||
logit_scores_decoder.append(beams[0][-2])
|
||||
lm_scores_decoder.append(beams[0][-1])
|
||||
pool.close()
|
||||
|
||||
self.assertListEqual(texts_decoder, decoded_processor.text)
|
||||
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor.text)
|
||||
@@ -242,15 +252,15 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
||||
decoded_processor = decoded_processor_out.text
|
||||
|
||||
logits_list = [array for array in logits]
|
||||
pool = get_context("fork").Pool()
|
||||
decoded_decoder_out = decoder.decode_beams_batch(
|
||||
pool,
|
||||
logits_list,
|
||||
beam_width=beam_width,
|
||||
beam_prune_logp=beam_prune_logp,
|
||||
token_min_logp=token_min_logp,
|
||||
)
|
||||
pool.close()
|
||||
|
||||
with get_context("fork").Pool() as pool:
|
||||
decoded_decoder_out = decoder.decode_beams_batch(
|
||||
pool,
|
||||
logits_list,
|
||||
beam_width=beam_width,
|
||||
beam_prune_logp=beam_prune_logp,
|
||||
token_min_logp=token_min_logp,
|
||||
)
|
||||
|
||||
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
|
||||
|
||||
@@ -287,12 +297,12 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
||||
unk_score_offset=unk_score_offset,
|
||||
lm_score_boundary=lm_score_boundary,
|
||||
)
|
||||
pool = get_context("fork").Pool()
|
||||
decoded_decoder_out = decoder.decode_beams_batch(
|
||||
pool,
|
||||
logits_list,
|
||||
)
|
||||
pool.close()
|
||||
|
||||
with get_context("fork").Pool() as pool:
|
||||
decoded_decoder_out = decoder.decode_beams_batch(
|
||||
pool,
|
||||
logits_list,
|
||||
)
|
||||
|
||||
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user