diff --git a/docs/source/en/model_doc/wav2vec2.mdx b/docs/source/en/model_doc/wav2vec2.mdx
index eaca36be46..79870f5623 100644
--- a/docs/source/en/model_doc/wav2vec2.mdx
+++ b/docs/source/en/model_doc/wav2vec2.mdx
@@ -73,6 +73,61 @@ This model was contributed by [patrickvonplaten](https://huggingface.co/patrickv
- batch_decode
- decode
+### Decoding multiple audios
+
+If you are planning to decode multiple batches of audios, you should consider using [`~Wav2Vec2ProcessorWithLM.batch_decode`] and passing an instantiated `multiprocessing.Pool`.
+Otherwise, [`~Wav2Vec2ProcessorWithLM.batch_decode`] performance will be slower than calling [`~Wav2Vec2ProcessorWithLM.decode`] for each audio individually, as it internally instantiates a new `Pool` for every call. See the example below:
+
+```python
+>>> # Let's see how to use a user-managed pool for batch decoding multiple audios
+>>> from multiprocessing import get_context
+>>> from transformers import AutoTokenizer, AutoProcessor, AutoModelForCTC
+>>> from datasets import load_dataset
+>>> import datasets
+>>> import torch
+
+>>> # import model, feature extractor, tokenizer
+>>> model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm").to("cuda")
+>>> processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
+
+>>> # load example dataset
+>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+>>> dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000))
+
+
+>>> def map_to_array(batch):
+... batch["speech"] = batch["audio"]["array"]
+... return batch
+
+
+>>> # prepare speech data for batch inference
+>>> dataset = dataset.map(map_to_array, remove_columns=["audio"])
+
+
+>>> def map_to_pred(batch, pool):
+... inputs = processor(batch["speech"], sampling_rate=16_000, padding=True, return_tensors="pt")
+... inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+... with torch.no_grad():
+... logits = model(**inputs).logits
+
+... transcription = processor.batch_decode(logits.cpu().numpy(), pool).text
+... batch["transcription"] = transcription
+... return batch
+
+
+>>> # note: pool should be instantiated *after* `Wav2Vec2ProcessorWithLM`.
+>>> # otherwise, the LM won't be available to the pool's sub-processes
+>>> # select number of processes and batch_size based on number of CPU cores available and on dataset size
+>>> with get_context("fork").Pool(processes=2) as pool:
+... result = dataset.map(
+... map_to_pred, batched=True, batch_size=2, fn_kwargs={"pool": pool}, remove_columns=["speech"]
+... )
+
+>>> result["transcription"][:2]
+['MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL', "NOR IS MISTER COULTER'S MANNER LESS INTERESTING THAN HIS MATTER"]
+```
+
## Wav2Vec2 specific outputs
[[autodoc]] models.wav2vec2_with_lm.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput
diff --git a/setup.py b/setup.py
index 13f8b42d97..cb5c50b058 100644
--- a/setup.py
+++ b/setup.py
@@ -164,7 +164,7 @@ _deps = [
"tokenizers>=0.11.1,!=0.11.3,<0.14",
"torch>=1.7,!=1.12.0",
"torchaudio",
- "pyctcdecode>=0.3.0",
+ "pyctcdecode>=0.4.0",
"tqdm>=4.27",
"unidic>=1.0.2",
"unidic_lite>=1.0.7",
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index 48a803fcdf..e55219b796 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -70,7 +70,7 @@ deps = {
"tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.14",
"torch": "torch>=1.7,!=1.12.0",
"torchaudio": "torchaudio",
- "pyctcdecode": "pyctcdecode>=0.3.0",
+ "pyctcdecode": "pyctcdecode>=0.4.0",
"tqdm": "tqdm>=4.27",
"unidic": "unidic>=1.0.2",
"unidic_lite": "unidic_lite>=1.0.7",
diff --git a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py
index 1e77959400..8d8406817d 100644
--- a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py
@@ -442,9 +442,9 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
- Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
- understand how to make use of `output_word_offsets`.
- [`~model.wav2vec2.tokenization_wav2vec2.batch_decode`] works the same way with batched output.
+ Please take a look at the Example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to make
+ use of `output_char_offsets`. [`~Wav2Vec2CTCTokenizer.batch_decode`] works the same way with batched
+ output.
@@ -454,9 +454,9 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
- Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
- understand how to make use of `output_word_offsets`.
- [`~model.wav2vec2.tokenization_wav2vec2.batch_decode`] works the same way with batched output.
+ Please take a look at the Example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to make
+ use of `output_word_offsets`. [`~Wav2Vec2CTCTokenizer.batch_decode`] works the same way with batched
+ output.
@@ -515,8 +515,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
- Please take a look at the example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
- understand how to make use of `output_word_offsets`.
+ Please take a look at the example below to better understand how to make use of `output_char_offsets`.
@@ -526,8 +525,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
- Please take a look at the example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
- understand how to make use of `output_word_offsets`.
+ Please take a look at the example below to better understand how to make use of `output_word_offsets`.
diff --git a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
index f09b5eb922..5c93b68507 100644
--- a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
+++ b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
@@ -17,15 +17,18 @@ Speech processor class for Wav2Vec2
"""
import os
import warnings
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
-from multiprocessing import get_context
+from multiprocessing import Pool, get_context, get_start_method
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union
import numpy as np
from ...processing_utils import ProcessorMixin
-from ...utils import ModelOutput, requires_backends
+from ...utils import ModelOutput, logging, requires_backends
+
+
+logger = logging.get_logger(__name__)
if TYPE_CHECKING:
@@ -115,7 +118,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
This class method is simply calling Wav2Vec2FeatureExtractor's
[`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`], Wav2Vec2CTCTokenizer's
- [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`], and
+ [`~tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`], and
[`pyctcdecode.BeamSearchDecoderCTC.load_from_hf_hub`].
Please refer to the docstrings of the methods above for more information.
@@ -280,6 +283,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
def batch_decode(
self,
logits: np.ndarray,
+ pool: Optional[Pool] = None,
num_processes: Optional[int] = None,
beam_width: Optional[int] = None,
beam_prune_logp: Optional[float] = None,
@@ -297,16 +301,32 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
- This function makes use of Python's multiprocessing.
+ This function makes use of Python's multiprocessing. Currently, multiprocessing is available only on Unix
+ systems (see this [issue](https://github.com/kensho-technologies/pyctcdecode/issues/65)).
+
+ If you are decoding multiple batches, consider creating a `Pool` and passing it to `batch_decode`. Otherwise,
+ `batch_decode` will be very slow since it will create a fresh `Pool` for each call. See usage example below.
Args:
logits (`np.ndarray`):
The logits output vector of the model representing the log probabilities for each token.
+ pool (`multiprocessing.Pool`, *optional*):
+ An optional user-managed pool. If not set, one will be automatically created and closed. The pool
+ should be instantiated *after* `Wav2Vec2ProcessorWithLM`. Otherwise, the LM won't be available to the
+ pool's sub-processes.
+
+
+
+ Currently, only pools created with a 'fork' context can be used. If a 'spawn' pool is passed, it will
+ be ignored and sequential decoding will be used instead.
+
+
+
num_processes (`int`, *optional*):
- Number of processes on which the function should be parallelized over. Defaults to the number of
- available CPUs.
+ If `pool` is not set, number of processes on which the function should be parallelized over. Defaults
+ to the number of available CPUs.
beam_width (`int`, *optional*):
Maximum number of beams at each step in decoding. Defaults to pyctcdecode's DEFAULT_BEAM_WIDTH.
beam_prune_logp (`int`, *optional*):
@@ -332,17 +352,19 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
- Please take a look at the Example of [`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to
- better understand how to make use of `output_word_offsets`.
- [`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.batch_decode`] works the same way with batched
- output.
+ Please take a look at the Example of [`~Wav2Vec2ProcessorWithLM.decode`] to better understand how to
+ make use of `output_word_offsets`. [`~Wav2Vec2ProcessorWithLM.batch_decode`] works the same way with
+ batched output.
Returns:
- [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
+ [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`].
+ Example:
+ See [Decoding multiple audios](#decoding-multiple-audios).
"""
+
from pyctcdecode.constants import (
DEFAULT_BEAM_WIDTH,
DEFAULT_HOTWORD_WEIGHT,
@@ -364,21 +386,41 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
# create multiprocessing pool and list numpy arrays
# filter out logits padding
logits_list = [array[(array != -100.0).all(axis=-1)] for array in logits]
- pool = get_context("fork").Pool(num_processes)
+
+ # create a pool if necessary while also using it as a context manager to close itself
+ if pool is None:
+ # fork is safe to use only on Unix, see "Contexts and start methods" section on
+ # multiprocessing's docs (https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)
+ default_context = get_start_method()
+
+ if default_context == "fork":
+ cm = pool = get_context().Pool(num_processes)
+ else:
+ logger.warning(
+ "Parallel batch decoding is not currently supported in this platform. "
+ "Falling back to sequential decoding."
+ )
+ cm = nullcontext()
+ else:
+ # pool is managed by the user, so we don't need to close it
+ cm = nullcontext()
+
+ if num_processes is not None:
+ logger.warning(
+ "Parameter `num_process` was passed, but it will be ignored since `pool` was also specified."
+ )
# pyctcdecode
- decoded_beams = self.decoder.decode_beams_batch(
- pool,
- logits_list=logits_list,
- beam_width=beam_width,
- beam_prune_logp=beam_prune_logp,
- token_min_logp=token_min_logp,
- hotwords=hotwords,
- hotword_weight=hotword_weight,
- )
-
- # clone multi-processing pool
- pool.close()
+ with cm:
+ decoded_beams = self.decoder.decode_beams_batch(
+ pool=pool,
+ logits_list=logits_list,
+ beam_width=beam_width,
+ beam_prune_logp=beam_prune_logp,
+ token_min_logp=token_min_logp,
+ hotwords=hotwords,
+ hotword_weight=hotword_weight,
+ )
# extract text and scores
batch_texts, logit_scores, lm_scores, word_offsets = [], [], [], []
@@ -440,13 +482,12 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
- Please take a look at the example of [`~models.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to
- better understand how to make use of `output_word_offsets`.
+ Please take a look at the example below to better understand how to make use of `output_word_offsets`.
Returns:
- [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
+ [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`].
Example:
diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py
index 3915c3f8a5..569e9975b1 100644
--- a/src/transformers/processing_utils.py
+++ b/src/transformers/processing_utils.py
@@ -99,8 +99,8 @@ class ProcessorMixin(PushToHubMixin):
This class method is simply calling [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] and
- [`~tokenization_utils_base.PreTrainedTokenizer.save_pretrained`]. Please refer to the docstrings of the methods
- above for more information.
+ [`~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`]. Please refer to the docstrings of the
+ methods above for more information.
diff --git a/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py b/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py
index aa6781a42e..b5524a4f69 100644
--- a/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py
+++ b/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py
@@ -14,6 +14,7 @@
import inspect
import math
+import multiprocessing
import unittest
import numpy as np
@@ -21,6 +22,7 @@ from datasets import load_dataset
from transformers import Wav2Vec2Config, is_flax_available
from transformers.testing_utils import (
+ CaptureLogger,
is_flaky,
is_librosa_available,
is_pt_flax_cross_test,
@@ -53,6 +55,7 @@ if is_flax_available():
if is_pyctcdecode_available():
from transformers import Wav2Vec2ProcessorWithLM
+ from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
if is_librosa_available():
@@ -554,3 +557,58 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
transcription = processor.batch_decode(np.array(logits)).text
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
+
+ @require_pyctcdecode
+ @require_librosa
+ def test_wav2vec2_with_lm_pool(self):
+ ds = load_dataset("common_voice", "es", split="test", streaming=True)
+ sample = next(iter(ds))
+
+ resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
+
+ model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
+ processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
+
+ input_values = processor(resampled_audio, return_tensors="np").input_values
+
+ logits = model(input_values).logits
+
+ # test user-managed pool
+ with multiprocessing.get_context("fork").Pool(2) as pool:
+ transcription = processor.batch_decode(logits.numpy(), pool).text
+
+ self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
+
+ # user-managed pool + num_processes should trigger a warning
+ with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
+ 2
+ ) as pool:
+ transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
+
+ self.assertIn("num_process", cl.out)
+ self.assertIn("it will be ignored", cl.out)
+
+ self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
+
+ @require_pyctcdecode
+ @require_librosa
+ def test_wav2vec2_with_lm_invalid_pool(self):
+ ds = load_dataset("common_voice", "es", split="test", streaming=True)
+ sample = next(iter(ds))
+
+ resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
+
+ model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
+ processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
+
+ input_values = processor(resampled_audio, return_tensors="np").input_values
+
+ logits = model(input_values).logits
+
+ # change default start method, which should trigger a warning if different than fork
+ multiprocessing.set_start_method("spawn")
+ with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
+ transcription = processor.batch_decode(logits.numpy()).text
+
+ self.assertIn("Falling back to sequential decoding.", cl.out)
+ self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
diff --git a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
index 6ea1919a33..5c2a25f413 100644
--- a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
+++ b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
@@ -18,6 +18,7 @@ import copy
import glob
import inspect
import math
+import multiprocessing
import unittest
import numpy as np
@@ -26,7 +27,7 @@ from datasets import load_dataset
from huggingface_hub import snapshot_download
from transformers import Wav2Vec2Config, is_tf_available
-from transformers.testing_utils import is_flaky, require_librosa, require_pyctcdecode, require_tf, slow
+from transformers.testing_utils import CaptureLogger, is_flaky, require_librosa, require_pyctcdecode, require_tf, slow
from transformers.utils import is_librosa_available, is_pyctcdecode_available
from ...test_configuration_common import ConfigTester
@@ -42,6 +43,7 @@ if is_tf_available():
if is_pyctcdecode_available():
from transformers import Wav2Vec2ProcessorWithLM
+ from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
if is_librosa_available():
@@ -590,3 +592,56 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
transcription = processor.batch_decode(logits.numpy()).text
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
+
+ @require_pyctcdecode
+ @require_librosa
+ def test_wav2vec2_with_lm_pool(self):
+ downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
+ file_path = glob.glob(downloaded_folder + "/*")[0]
+ sample = librosa.load(file_path, sr=16_000)[0]
+
+ model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
+ processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
+
+ input_values = processor(sample, return_tensors="tf").input_values
+
+ logits = model(input_values).logits
+
+ # test user-managed pool
+ with multiprocessing.get_context("fork").Pool(2) as pool:
+ transcription = processor.batch_decode(logits.numpy(), pool).text
+
+ self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
+
+ # user-managed pool + num_processes should trigger a warning
+ with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
+ 2
+ ) as pool:
+ transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
+
+ self.assertIn("num_process", cl.out)
+ self.assertIn("it will be ignored", cl.out)
+
+ self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
+
+ @require_pyctcdecode
+ @require_librosa
+ def test_wav2vec2_with_lm_invalid_pool(self):
+ downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
+ file_path = glob.glob(downloaded_folder + "/*")[0]
+ sample = librosa.load(file_path, sr=16_000)[0]
+
+ model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
+ processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
+
+ input_values = processor(sample, return_tensors="tf").input_values
+
+ logits = model(input_values).logits
+
+ # change default start method, which should trigger a warning if different than fork
+ multiprocessing.set_start_method("spawn")
+ with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
+ transcription = processor.batch_decode(logits.numpy()).text
+
+ self.assertIn("Falling back to sequential decoding.", cl.out)
+ self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py
index 040731472f..4d4fa2981d 100644
--- a/tests/models/wav2vec2/test_modeling_wav2vec2.py
+++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py
@@ -15,6 +15,7 @@
""" Testing suite for the PyTorch Wav2Vec2 model. """
import math
+import multiprocessing
import os
import pickle
import tempfile
@@ -25,6 +26,7 @@ from datasets import load_dataset
from transformers import Wav2Vec2Config, is_torch_available
from transformers.testing_utils import (
+ CaptureLogger,
is_pt_flax_cross_test,
is_pyctcdecode_available,
is_torchaudio_available,
@@ -74,6 +76,7 @@ if is_torchaudio_available():
if is_pyctcdecode_available():
from transformers import Wav2Vec2ProcessorWithLM
+ from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
if is_torch_fx_available():
@@ -1611,6 +1614,71 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
+ @require_pyctcdecode
+ @require_torchaudio
+ def test_wav2vec2_with_lm_pool(self):
+ ds = load_dataset("common_voice", "es", split="test", streaming=True)
+ sample = next(iter(ds))
+
+ resampled_audio = torchaudio.functional.resample(
+ torch.tensor(sample["audio"]["array"]), 48_000, 16_000
+ ).numpy()
+
+ model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
+ torch_device
+ )
+ processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
+
+ input_values = processor(resampled_audio, return_tensors="pt").input_values
+
+ with torch.no_grad():
+ logits = model(input_values.to(torch_device)).logits
+
+ # test user-managed pool
+ with multiprocessing.get_context("fork").Pool(2) as pool:
+ transcription = processor.batch_decode(logits.numpy(), pool).text
+
+ self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
+
+ # user-managed pool + num_processes should trigger a warning
+ with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
+ 2
+ ) as pool:
+ transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
+
+ self.assertIn("num_process", cl.out)
+ self.assertIn("it will be ignored", cl.out)
+
+ self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
+
+ @require_pyctcdecode
+ @require_torchaudio
+ def test_wav2vec2_with_lm_invalid_pool(self):
+ ds = load_dataset("common_voice", "es", split="test", streaming=True)
+ sample = next(iter(ds))
+
+ resampled_audio = torchaudio.functional.resample(
+ torch.tensor(sample["audio"]["array"]), 48_000, 16_000
+ ).numpy()
+
+ model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
+ torch_device
+ )
+ processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
+
+ input_values = processor(resampled_audio, return_tensors="pt").input_values
+
+ with torch.no_grad():
+ logits = model(input_values.to(torch_device)).logits
+
+ # change default start method, which should trigger a warning if different than fork
+ multiprocessing.set_start_method("spawn")
+ with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
+ transcription = processor.batch_decode(logits.numpy()).text
+
+ self.assertIn("Falling back to sequential decoding.", cl.out)
+ self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
+
def test_inference_diarization(self):
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/wav2vec2-base-superb-sd")
diff --git a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py
index 6bf52d3e1b..11a45b6e1a 100644
--- a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py
+++ b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py
@@ -25,6 +25,7 @@ import numpy as np
from datasets import load_dataset
from packaging import version
+from parameterized import parameterized
from transformers import AutoProcessor
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
@@ -194,7 +195,8 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
self.assertEqual(decoded_decoder[-2], decoded_processor.logit_score)
self.assertEqual(decoded_decoder[-1], decoded_processor.lm_score)
- def test_decoder_batch(self):
+ @parameterized.expand([[None], ["fork"], ["spawn"]])
+ def test_decoder_batch(self, pool_context):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
decoder = self.get_decoder()
@@ -203,17 +205,25 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
logits = self._get_dummy_logits()
- decoded_processor = processor.batch_decode(logits)
+ # note: pool should be instantiated *after* Wav2Vec2ProcessorWithLM.
+ # otherwise, the LM won't be available to the pool's sub-processes.
+ # manual logic used to allow parameterized test for both pool=None and pool=Pool(...)
+ if pool_context is None:
+ decoded_processor = processor.batch_decode(logits)
+ else:
+ with get_context(pool_context).Pool() as pool:
+ decoded_processor = processor.batch_decode(logits, pool)
logits_list = [array for array in logits]
- pool = get_context("fork").Pool()
- decoded_beams = decoder.decode_beams_batch(pool, logits_list)
+
+ with get_context("fork").Pool() as p:
+ decoded_beams = decoder.decode_beams_batch(p, logits_list)
+
texts_decoder, logit_scores_decoder, lm_scores_decoder = [], [], []
for beams in decoded_beams:
texts_decoder.append(beams[0][0])
logit_scores_decoder.append(beams[0][-2])
lm_scores_decoder.append(beams[0][-1])
- pool.close()
self.assertListEqual(texts_decoder, decoded_processor.text)
self.assertListEqual([" ", " "], decoded_processor.text)
@@ -242,15 +252,15 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
decoded_processor = decoded_processor_out.text
logits_list = [array for array in logits]
- pool = get_context("fork").Pool()
- decoded_decoder_out = decoder.decode_beams_batch(
- pool,
- logits_list,
- beam_width=beam_width,
- beam_prune_logp=beam_prune_logp,
- token_min_logp=token_min_logp,
- )
- pool.close()
+
+ with get_context("fork").Pool() as pool:
+ decoded_decoder_out = decoder.decode_beams_batch(
+ pool,
+ logits_list,
+ beam_width=beam_width,
+ beam_prune_logp=beam_prune_logp,
+ token_min_logp=token_min_logp,
+ )
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
@@ -287,12 +297,12 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
unk_score_offset=unk_score_offset,
lm_score_boundary=lm_score_boundary,
)
- pool = get_context("fork").Pool()
- decoded_decoder_out = decoder.decode_beams_batch(
- pool,
- logits_list,
- )
- pool.close()
+
+ with get_context("fork").Pool() as pool:
+ decoded_decoder_out = decoder.decode_beams_batch(
+ pool,
+ logits_list,
+ )
decoded_decoder = [d[0][0] for d in decoded_decoder_out]