Allow user-managed Pool in Wav2Vec2ProcessorWithLM.batch_decode (#18351)
* [Wav2Vec2] Allow user-managed Pool in Wav2Vec2ProcessorWithLM.batch_decode * [Wav2Vec2] Add user-managed LM's pool tests and usage examples * Improve styling Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * [Wav2Vec2] Fix hyperlink references Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
bf0e094142
commit
af150e4a1c
@@ -73,6 +73,61 @@ This model was contributed by [patrickvonplaten](https://huggingface.co/patrickv
|
||||
- batch_decode
|
||||
- decode
|
||||
|
||||
### Decoding multiple audios
|
||||
|
||||
If you are planning to decode multiple batches of audios, you should consider using [`~Wav2Vec2ProcessorWithLM.batch_decode`] and passing an instantiated `multiprocessing.Pool`.
|
||||
Otherwise, [`~Wav2Vec2ProcessorWithLM.batch_decode`] performance will be slower than calling [`~Wav2Vec2ProcessorWithLM.decode`] for each audio individually, as it internally instantiates a new `Pool` for every call. See the example below:
|
||||
|
||||
```python
|
||||
>>> # Let's see how to use a user-managed pool for batch decoding multiple audios
|
||||
>>> from multiprocessing import get_context
|
||||
>>> from transformers import AutoTokenizer, AutoProcessor, AutoModelForCTC
|
||||
>>> from datasets import load_dataset
|
||||
>>> import datasets
|
||||
>>> import torch
|
||||
|
||||
>>> # import model, feature extractor, tokenizer
|
||||
>>> model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm").to("cuda")
|
||||
>>> processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
|
||||
|
||||
>>> # load example dataset
|
||||
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000))
|
||||
|
||||
|
||||
>>> def map_to_array(batch):
|
||||
... batch["speech"] = batch["audio"]["array"]
|
||||
... return batch
|
||||
|
||||
|
||||
>>> # prepare speech data for batch inference
|
||||
>>> dataset = dataset.map(map_to_array, remove_columns=["audio"])
|
||||
|
||||
|
||||
>>> def map_to_pred(batch, pool):
|
||||
... inputs = processor(batch["speech"], sampling_rate=16_000, padding=True, return_tensors="pt")
|
||||
... inputs = {k: v.to("cuda") for k, v in inputs.items()}
|
||||
|
||||
... with torch.no_grad():
|
||||
... logits = model(**inputs).logits
|
||||
|
||||
... transcription = processor.batch_decode(logits.cpu().numpy(), pool).text
|
||||
... batch["transcription"] = transcription
|
||||
... return batch
|
||||
|
||||
|
||||
>>> # note: pool should be instantiated *after* `Wav2Vec2ProcessorWithLM`.
|
||||
>>> # otherwise, the LM won't be available to the pool's sub-processes
|
||||
>>> # select number of processes and batch_size based on number of CPU cores available and on dataset size
|
||||
>>> with get_context("fork").Pool(processes=2) as pool:
|
||||
... result = dataset.map(
|
||||
... map_to_pred, batched=True, batch_size=2, fn_kwargs={"pool": pool}, remove_columns=["speech"]
|
||||
... )
|
||||
|
||||
>>> result["transcription"][:2]
|
||||
['MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL', "NOR IS MISTER COULTER'S MANNER LESS INTERESTING THAN HIS MATTER"]
|
||||
```
|
||||
|
||||
## Wav2Vec2 specific outputs
|
||||
|
||||
[[autodoc]] models.wav2vec2_with_lm.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput
|
||||
|
||||
2
setup.py
2
setup.py
@@ -164,7 +164,7 @@ _deps = [
|
||||
"tokenizers>=0.11.1,!=0.11.3,<0.14",
|
||||
"torch>=1.7,!=1.12.0",
|
||||
"torchaudio",
|
||||
"pyctcdecode>=0.3.0",
|
||||
"pyctcdecode>=0.4.0",
|
||||
"tqdm>=4.27",
|
||||
"unidic>=1.0.2",
|
||||
"unidic_lite>=1.0.7",
|
||||
|
||||
@@ -70,7 +70,7 @@ deps = {
|
||||
"tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.14",
|
||||
"torch": "torch>=1.7,!=1.12.0",
|
||||
"torchaudio": "torchaudio",
|
||||
"pyctcdecode": "pyctcdecode>=0.3.0",
|
||||
"pyctcdecode": "pyctcdecode>=0.4.0",
|
||||
"tqdm": "tqdm>=4.27",
|
||||
"unidic": "unidic>=1.0.2",
|
||||
"unidic_lite": "unidic_lite>=1.0.7",
|
||||
|
||||
@@ -442,9 +442,9 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
||||
|
||||
<Tip>
|
||||
|
||||
Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
|
||||
understand how to make use of `output_word_offsets`.
|
||||
[`~model.wav2vec2.tokenization_wav2vec2.batch_decode`] works the same way with batched output.
|
||||
Please take a look at the Example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to make
|
||||
use of `output_char_offsets`. [`~Wav2Vec2CTCTokenizer.batch_decode`] works the same way with batched
|
||||
output.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -454,9 +454,9 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
||||
|
||||
<Tip>
|
||||
|
||||
Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
|
||||
understand how to make use of `output_word_offsets`.
|
||||
[`~model.wav2vec2.tokenization_wav2vec2.batch_decode`] works the same way with batched output.
|
||||
Please take a look at the Example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to make
|
||||
use of `output_word_offsets`. [`~Wav2Vec2CTCTokenizer.batch_decode`] works the same way with batched
|
||||
output.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -515,8 +515,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
||||
|
||||
<Tip>
|
||||
|
||||
Please take a look at the example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
|
||||
understand how to make use of `output_word_offsets`.
|
||||
Please take a look at the example below to better understand how to make use of `output_char_offsets`.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -526,8 +525,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
||||
|
||||
<Tip>
|
||||
|
||||
Please take a look at the example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
|
||||
understand how to make use of `output_word_offsets`.
|
||||
Please take a look at the example below to better understand how to make use of `output_word_offsets`.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
@@ -17,15 +17,18 @@ Speech processor class for Wav2Vec2
|
||||
"""
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import get_context
|
||||
from multiprocessing import Pool, get_context, get_start_method
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...utils import ModelOutput, requires_backends
|
||||
from ...utils import ModelOutput, logging, requires_backends
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -115,7 +118,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
||||
|
||||
This class method is simply calling Wav2Vec2FeatureExtractor's
|
||||
[`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`], Wav2Vec2CTCTokenizer's
|
||||
[`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`], and
|
||||
[`~tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`], and
|
||||
[`pyctcdecode.BeamSearchDecoderCTC.load_from_hf_hub`].
|
||||
|
||||
Please refer to the docstrings of the methods above for more information.
|
||||
@@ -280,6 +283,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
||||
def batch_decode(
|
||||
self,
|
||||
logits: np.ndarray,
|
||||
pool: Optional[Pool] = None,
|
||||
num_processes: Optional[int] = None,
|
||||
beam_width: Optional[int] = None,
|
||||
beam_prune_logp: Optional[float] = None,
|
||||
@@ -297,16 +301,32 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
||||
|
||||
<Tip>
|
||||
|
||||
This function makes use of Python's multiprocessing.
|
||||
This function makes use of Python's multiprocessing. Currently, multiprocessing is available only on Unix
|
||||
systems (see this [issue](https://github.com/kensho-technologies/pyctcdecode/issues/65)).
|
||||
|
||||
If you are decoding multiple batches, consider creating a `Pool` and passing it to `batch_decode`. Otherwise,
|
||||
`batch_decode` will be very slow since it will create a fresh `Pool` for each call. See usage example below.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
logits (`np.ndarray`):
|
||||
The logits output vector of the model representing the log probabilities for each token.
|
||||
pool (`multiprocessing.Pool`, *optional*):
|
||||
An optional user-managed pool. If not set, one will be automatically created and closed. The pool
|
||||
should be instantiated *after* `Wav2Vec2ProcessorWithLM`. Otherwise, the LM won't be available to the
|
||||
pool's sub-processes.
|
||||
|
||||
<Tip>
|
||||
|
||||
Currently, only pools created with a 'fork' context can be used. If a 'spawn' pool is passed, it will
|
||||
be ignored and sequential decoding will be used instead.
|
||||
|
||||
</Tip>
|
||||
|
||||
num_processes (`int`, *optional*):
|
||||
Number of processes on which the function should be parallelized over. Defaults to the number of
|
||||
available CPUs.
|
||||
If `pool` is not set, number of processes on which the function should be parallelized over. Defaults
|
||||
to the number of available CPUs.
|
||||
beam_width (`int`, *optional*):
|
||||
Maximum number of beams at each step in decoding. Defaults to pyctcdecode's DEFAULT_BEAM_WIDTH.
|
||||
beam_prune_logp (`int`, *optional*):
|
||||
@@ -332,17 +352,19 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
||||
|
||||
<Tip>
|
||||
|
||||
Please take a look at the Example of [`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to
|
||||
better understand how to make use of `output_word_offsets`.
|
||||
[`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.batch_decode`] works the same way with batched
|
||||
output.
|
||||
Please take a look at the Example of [`~Wav2Vec2ProcessorWithLM.decode`] to better understand how to
|
||||
make use of `output_word_offsets`. [`~Wav2Vec2ProcessorWithLM.batch_decode`] works the same way with
|
||||
batched output.
|
||||
|
||||
</Tip>
|
||||
|
||||
Returns:
|
||||
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
|
||||
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`].
|
||||
|
||||
Example:
|
||||
See [Decoding multiple audios](#decoding-multiple-audios).
|
||||
"""
|
||||
|
||||
from pyctcdecode.constants import (
|
||||
DEFAULT_BEAM_WIDTH,
|
||||
DEFAULT_HOTWORD_WEIGHT,
|
||||
@@ -364,21 +386,41 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
||||
# create multiprocessing pool and list numpy arrays
|
||||
# filter out logits padding
|
||||
logits_list = [array[(array != -100.0).all(axis=-1)] for array in logits]
|
||||
pool = get_context("fork").Pool(num_processes)
|
||||
|
||||
# create a pool if necessary while also using it as a context manager to close itself
|
||||
if pool is None:
|
||||
# fork is safe to use only on Unix, see "Contexts and start methods" section on
|
||||
# multiprocessing's docs (https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)
|
||||
default_context = get_start_method()
|
||||
|
||||
if default_context == "fork":
|
||||
cm = pool = get_context().Pool(num_processes)
|
||||
else:
|
||||
logger.warning(
|
||||
"Parallel batch decoding is not currently supported in this platform. "
|
||||
"Falling back to sequential decoding."
|
||||
)
|
||||
cm = nullcontext()
|
||||
else:
|
||||
# pool is managed by the user, so we don't need to close it
|
||||
cm = nullcontext()
|
||||
|
||||
if num_processes is not None:
|
||||
logger.warning(
|
||||
"Parameter `num_process` was passed, but it will be ignored since `pool` was also specified."
|
||||
)
|
||||
|
||||
# pyctcdecode
|
||||
decoded_beams = self.decoder.decode_beams_batch(
|
||||
pool,
|
||||
logits_list=logits_list,
|
||||
beam_width=beam_width,
|
||||
beam_prune_logp=beam_prune_logp,
|
||||
token_min_logp=token_min_logp,
|
||||
hotwords=hotwords,
|
||||
hotword_weight=hotword_weight,
|
||||
)
|
||||
|
||||
# clone multi-processing pool
|
||||
pool.close()
|
||||
with cm:
|
||||
decoded_beams = self.decoder.decode_beams_batch(
|
||||
pool=pool,
|
||||
logits_list=logits_list,
|
||||
beam_width=beam_width,
|
||||
beam_prune_logp=beam_prune_logp,
|
||||
token_min_logp=token_min_logp,
|
||||
hotwords=hotwords,
|
||||
hotword_weight=hotword_weight,
|
||||
)
|
||||
|
||||
# extract text and scores
|
||||
batch_texts, logit_scores, lm_scores, word_offsets = [], [], [], []
|
||||
@@ -440,13 +482,12 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
||||
|
||||
<Tip>
|
||||
|
||||
Please take a look at the example of [`~models.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to
|
||||
better understand how to make use of `output_word_offsets`.
|
||||
Please take a look at the example below to better understand how to make use of `output_word_offsets`.
|
||||
|
||||
</Tip>
|
||||
|
||||
Returns:
|
||||
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
|
||||
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`].
|
||||
|
||||
Example:
|
||||
|
||||
|
||||
@@ -99,8 +99,8 @@ class ProcessorMixin(PushToHubMixin):
|
||||
<Tip>
|
||||
|
||||
This class method is simply calling [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] and
|
||||
[`~tokenization_utils_base.PreTrainedTokenizer.save_pretrained`]. Please refer to the docstrings of the methods
|
||||
above for more information.
|
||||
[`~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`]. Please refer to the docstrings of the
|
||||
methods above for more information.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
import inspect
|
||||
import math
|
||||
import multiprocessing
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -21,6 +22,7 @@ from datasets import load_dataset
|
||||
|
||||
from transformers import Wav2Vec2Config, is_flax_available
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
is_flaky,
|
||||
is_librosa_available,
|
||||
is_pt_flax_cross_test,
|
||||
@@ -53,6 +55,7 @@ if is_flax_available():
|
||||
|
||||
if is_pyctcdecode_available():
|
||||
from transformers import Wav2Vec2ProcessorWithLM
|
||||
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
|
||||
|
||||
|
||||
if is_librosa_available():
|
||||
@@ -554,3 +557,58 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
transcription = processor.batch_decode(np.array(logits)).text
|
||||
|
||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||
|
||||
@require_pyctcdecode
|
||||
@require_librosa
|
||||
def test_wav2vec2_with_lm_pool(self):
|
||||
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
||||
sample = next(iter(ds))
|
||||
|
||||
resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
|
||||
|
||||
model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
|
||||
input_values = processor(resampled_audio, return_tensors="np").input_values
|
||||
|
||||
logits = model(input_values).logits
|
||||
|
||||
# test user-managed pool
|
||||
with multiprocessing.get_context("fork").Pool(2) as pool:
|
||||
transcription = processor.batch_decode(logits.numpy(), pool).text
|
||||
|
||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||
|
||||
# user-managed pool + num_processes should trigger a warning
|
||||
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
|
||||
2
|
||||
) as pool:
|
||||
transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
|
||||
|
||||
self.assertIn("num_process", cl.out)
|
||||
self.assertIn("it will be ignored", cl.out)
|
||||
|
||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||
|
||||
@require_pyctcdecode
|
||||
@require_librosa
|
||||
def test_wav2vec2_with_lm_invalid_pool(self):
|
||||
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
||||
sample = next(iter(ds))
|
||||
|
||||
resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
|
||||
|
||||
model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
|
||||
input_values = processor(resampled_audio, return_tensors="np").input_values
|
||||
|
||||
logits = model(input_values).logits
|
||||
|
||||
# change default start method, which should trigger a warning if different than fork
|
||||
multiprocessing.set_start_method("spawn")
|
||||
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
|
||||
transcription = processor.batch_decode(logits.numpy()).text
|
||||
|
||||
self.assertIn("Falling back to sequential decoding.", cl.out)
|
||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||
|
||||
@@ -18,6 +18,7 @@ import copy
|
||||
import glob
|
||||
import inspect
|
||||
import math
|
||||
import multiprocessing
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -26,7 +27,7 @@ from datasets import load_dataset
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import Wav2Vec2Config, is_tf_available
|
||||
from transformers.testing_utils import is_flaky, require_librosa, require_pyctcdecode, require_tf, slow
|
||||
from transformers.testing_utils import CaptureLogger, is_flaky, require_librosa, require_pyctcdecode, require_tf, slow
|
||||
from transformers.utils import is_librosa_available, is_pyctcdecode_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -42,6 +43,7 @@ if is_tf_available():
|
||||
|
||||
if is_pyctcdecode_available():
|
||||
from transformers import Wav2Vec2ProcessorWithLM
|
||||
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
|
||||
|
||||
|
||||
if is_librosa_available():
|
||||
@@ -590,3 +592,56 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
transcription = processor.batch_decode(logits.numpy()).text
|
||||
|
||||
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
||||
|
||||
@require_pyctcdecode
|
||||
@require_librosa
|
||||
def test_wav2vec2_with_lm_pool(self):
|
||||
downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
|
||||
file_path = glob.glob(downloaded_folder + "/*")[0]
|
||||
sample = librosa.load(file_path, sr=16_000)[0]
|
||||
|
||||
model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
|
||||
input_values = processor(sample, return_tensors="tf").input_values
|
||||
|
||||
logits = model(input_values).logits
|
||||
|
||||
# test user-managed pool
|
||||
with multiprocessing.get_context("fork").Pool(2) as pool:
|
||||
transcription = processor.batch_decode(logits.numpy(), pool).text
|
||||
|
||||
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
||||
|
||||
# user-managed pool + num_processes should trigger a warning
|
||||
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
|
||||
2
|
||||
) as pool:
|
||||
transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
|
||||
|
||||
self.assertIn("num_process", cl.out)
|
||||
self.assertIn("it will be ignored", cl.out)
|
||||
|
||||
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
||||
|
||||
@require_pyctcdecode
|
||||
@require_librosa
|
||||
def test_wav2vec2_with_lm_invalid_pool(self):
|
||||
downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
|
||||
file_path = glob.glob(downloaded_folder + "/*")[0]
|
||||
sample = librosa.load(file_path, sr=16_000)[0]
|
||||
|
||||
model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
|
||||
input_values = processor(sample, return_tensors="tf").input_values
|
||||
|
||||
logits = model(input_values).logits
|
||||
|
||||
# change default start method, which should trigger a warning if different than fork
|
||||
multiprocessing.set_start_method("spawn")
|
||||
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
|
||||
transcription = processor.batch_decode(logits.numpy()).text
|
||||
|
||||
self.assertIn("Falling back to sequential decoding.", cl.out)
|
||||
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
""" Testing suite for the PyTorch Wav2Vec2 model. """
|
||||
|
||||
import math
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
@@ -25,6 +26,7 @@ from datasets import load_dataset
|
||||
|
||||
from transformers import Wav2Vec2Config, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
is_pt_flax_cross_test,
|
||||
is_pyctcdecode_available,
|
||||
is_torchaudio_available,
|
||||
@@ -74,6 +76,7 @@ if is_torchaudio_available():
|
||||
|
||||
if is_pyctcdecode_available():
|
||||
from transformers import Wav2Vec2ProcessorWithLM
|
||||
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
|
||||
|
||||
|
||||
if is_torch_fx_available():
|
||||
@@ -1611,6 +1614,71 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||
|
||||
@require_pyctcdecode
|
||||
@require_torchaudio
|
||||
def test_wav2vec2_with_lm_pool(self):
|
||||
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
||||
sample = next(iter(ds))
|
||||
|
||||
resampled_audio = torchaudio.functional.resample(
|
||||
torch.tensor(sample["audio"]["array"]), 48_000, 16_000
|
||||
).numpy()
|
||||
|
||||
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
|
||||
torch_device
|
||||
)
|
||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
|
||||
input_values = processor(resampled_audio, return_tensors="pt").input_values
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(input_values.to(torch_device)).logits
|
||||
|
||||
# test user-managed pool
|
||||
with multiprocessing.get_context("fork").Pool(2) as pool:
|
||||
transcription = processor.batch_decode(logits.numpy(), pool).text
|
||||
|
||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||
|
||||
# user-managed pool + num_processes should trigger a warning
|
||||
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
|
||||
2
|
||||
) as pool:
|
||||
transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
|
||||
|
||||
self.assertIn("num_process", cl.out)
|
||||
self.assertIn("it will be ignored", cl.out)
|
||||
|
||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||
|
||||
@require_pyctcdecode
|
||||
@require_torchaudio
|
||||
def test_wav2vec2_with_lm_invalid_pool(self):
|
||||
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
||||
sample = next(iter(ds))
|
||||
|
||||
resampled_audio = torchaudio.functional.resample(
|
||||
torch.tensor(sample["audio"]["array"]), 48_000, 16_000
|
||||
).numpy()
|
||||
|
||||
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
|
||||
torch_device
|
||||
)
|
||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
|
||||
input_values = processor(resampled_audio, return_tensors="pt").input_values
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(input_values.to(torch_device)).logits
|
||||
|
||||
# change default start method, which should trigger a warning if different than fork
|
||||
multiprocessing.set_start_method("spawn")
|
||||
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
|
||||
transcription = processor.batch_decode(logits.numpy()).text
|
||||
|
||||
self.assertIn("Falling back to sequential decoding.", cl.out)
|
||||
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
||||
|
||||
def test_inference_diarization(self):
|
||||
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/wav2vec2-base-superb-sd")
|
||||
|
||||
@@ -25,6 +25,7 @@ import numpy as np
|
||||
from datasets import load_dataset
|
||||
from packaging import version
|
||||
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoProcessor
|
||||
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||
@@ -194,7 +195,8 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
||||
self.assertEqual(decoded_decoder[-2], decoded_processor.logit_score)
|
||||
self.assertEqual(decoded_decoder[-1], decoded_processor.lm_score)
|
||||
|
||||
def test_decoder_batch(self):
|
||||
@parameterized.expand([[None], ["fork"], ["spawn"]])
|
||||
def test_decoder_batch(self, pool_context):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
decoder = self.get_decoder()
|
||||
@@ -203,17 +205,25 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
||||
|
||||
logits = self._get_dummy_logits()
|
||||
|
||||
decoded_processor = processor.batch_decode(logits)
|
||||
# note: pool should be instantiated *after* Wav2Vec2ProcessorWithLM.
|
||||
# otherwise, the LM won't be available to the pool's sub-processes.
|
||||
# manual logic used to allow parameterized test for both pool=None and pool=Pool(...)
|
||||
if pool_context is None:
|
||||
decoded_processor = processor.batch_decode(logits)
|
||||
else:
|
||||
with get_context(pool_context).Pool() as pool:
|
||||
decoded_processor = processor.batch_decode(logits, pool)
|
||||
|
||||
logits_list = [array for array in logits]
|
||||
pool = get_context("fork").Pool()
|
||||
decoded_beams = decoder.decode_beams_batch(pool, logits_list)
|
||||
|
||||
with get_context("fork").Pool() as p:
|
||||
decoded_beams = decoder.decode_beams_batch(p, logits_list)
|
||||
|
||||
texts_decoder, logit_scores_decoder, lm_scores_decoder = [], [], []
|
||||
for beams in decoded_beams:
|
||||
texts_decoder.append(beams[0][0])
|
||||
logit_scores_decoder.append(beams[0][-2])
|
||||
lm_scores_decoder.append(beams[0][-1])
|
||||
pool.close()
|
||||
|
||||
self.assertListEqual(texts_decoder, decoded_processor.text)
|
||||
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor.text)
|
||||
@@ -242,15 +252,15 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
||||
decoded_processor = decoded_processor_out.text
|
||||
|
||||
logits_list = [array for array in logits]
|
||||
pool = get_context("fork").Pool()
|
||||
decoded_decoder_out = decoder.decode_beams_batch(
|
||||
pool,
|
||||
logits_list,
|
||||
beam_width=beam_width,
|
||||
beam_prune_logp=beam_prune_logp,
|
||||
token_min_logp=token_min_logp,
|
||||
)
|
||||
pool.close()
|
||||
|
||||
with get_context("fork").Pool() as pool:
|
||||
decoded_decoder_out = decoder.decode_beams_batch(
|
||||
pool,
|
||||
logits_list,
|
||||
beam_width=beam_width,
|
||||
beam_prune_logp=beam_prune_logp,
|
||||
token_min_logp=token_min_logp,
|
||||
)
|
||||
|
||||
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
|
||||
|
||||
@@ -287,12 +297,12 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
||||
unk_score_offset=unk_score_offset,
|
||||
lm_score_boundary=lm_score_boundary,
|
||||
)
|
||||
pool = get_context("fork").Pool()
|
||||
decoded_decoder_out = decoder.decode_beams_batch(
|
||||
pool,
|
||||
logits_list,
|
||||
)
|
||||
pool.close()
|
||||
|
||||
with get_context("fork").Pool() as pool:
|
||||
decoded_decoder_out = decoder.decode_beams_batch(
|
||||
pool,
|
||||
logits_list,
|
||||
)
|
||||
|
||||
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user