Allow user-managed Pool in Wav2Vec2ProcessorWithLM.batch_decode (#18351)
* [Wav2Vec2] Allow user-managed Pool in Wav2Vec2ProcessorWithLM.batch_decode * [Wav2Vec2] Add user-managed LM's pool tests and usage examples * Improve styling Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * [Wav2Vec2] Fix hyperlink references Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
bf0e094142
commit
af150e4a1c
@@ -73,6 +73,61 @@ This model was contributed by [patrickvonplaten](https://huggingface.co/patrickv
|
|||||||
- batch_decode
|
- batch_decode
|
||||||
- decode
|
- decode
|
||||||
|
|
||||||
|
### Decoding multiple audios
|
||||||
|
|
||||||
|
If you are planning to decode multiple batches of audios, you should consider using [`~Wav2Vec2ProcessorWithLM.batch_decode`] and passing an instantiated `multiprocessing.Pool`.
|
||||||
|
Otherwise, [`~Wav2Vec2ProcessorWithLM.batch_decode`] performance will be slower than calling [`~Wav2Vec2ProcessorWithLM.decode`] for each audio individually, as it internally instantiates a new `Pool` for every call. See the example below:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> # Let's see how to use a user-managed pool for batch decoding multiple audios
|
||||||
|
>>> from multiprocessing import get_context
|
||||||
|
>>> from transformers import AutoTokenizer, AutoProcessor, AutoModelForCTC
|
||||||
|
>>> from datasets import load_dataset
|
||||||
|
>>> import datasets
|
||||||
|
>>> import torch
|
||||||
|
|
||||||
|
>>> # import model, feature extractor, tokenizer
|
||||||
|
>>> model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm").to("cuda")
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
|
||||||
|
|
||||||
|
>>> # load example dataset
|
||||||
|
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||||
|
>>> dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000))
|
||||||
|
|
||||||
|
|
||||||
|
>>> def map_to_array(batch):
|
||||||
|
... batch["speech"] = batch["audio"]["array"]
|
||||||
|
... return batch
|
||||||
|
|
||||||
|
|
||||||
|
>>> # prepare speech data for batch inference
|
||||||
|
>>> dataset = dataset.map(map_to_array, remove_columns=["audio"])
|
||||||
|
|
||||||
|
|
||||||
|
>>> def map_to_pred(batch, pool):
|
||||||
|
... inputs = processor(batch["speech"], sampling_rate=16_000, padding=True, return_tensors="pt")
|
||||||
|
... inputs = {k: v.to("cuda") for k, v in inputs.items()}
|
||||||
|
|
||||||
|
... with torch.no_grad():
|
||||||
|
... logits = model(**inputs).logits
|
||||||
|
|
||||||
|
... transcription = processor.batch_decode(logits.cpu().numpy(), pool).text
|
||||||
|
... batch["transcription"] = transcription
|
||||||
|
... return batch
|
||||||
|
|
||||||
|
|
||||||
|
>>> # note: pool should be instantiated *after* `Wav2Vec2ProcessorWithLM`.
|
||||||
|
>>> # otherwise, the LM won't be available to the pool's sub-processes
|
||||||
|
>>> # select number of processes and batch_size based on number of CPU cores available and on dataset size
|
||||||
|
>>> with get_context("fork").Pool(processes=2) as pool:
|
||||||
|
... result = dataset.map(
|
||||||
|
... map_to_pred, batched=True, batch_size=2, fn_kwargs={"pool": pool}, remove_columns=["speech"]
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> result["transcription"][:2]
|
||||||
|
['MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL', "NOR IS MISTER COULTER'S MANNER LESS INTERESTING THAN HIS MATTER"]
|
||||||
|
```
|
||||||
|
|
||||||
## Wav2Vec2 specific outputs
|
## Wav2Vec2 specific outputs
|
||||||
|
|
||||||
[[autodoc]] models.wav2vec2_with_lm.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput
|
[[autodoc]] models.wav2vec2_with_lm.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -164,7 +164,7 @@ _deps = [
|
|||||||
"tokenizers>=0.11.1,!=0.11.3,<0.14",
|
"tokenizers>=0.11.1,!=0.11.3,<0.14",
|
||||||
"torch>=1.7,!=1.12.0",
|
"torch>=1.7,!=1.12.0",
|
||||||
"torchaudio",
|
"torchaudio",
|
||||||
"pyctcdecode>=0.3.0",
|
"pyctcdecode>=0.4.0",
|
||||||
"tqdm>=4.27",
|
"tqdm>=4.27",
|
||||||
"unidic>=1.0.2",
|
"unidic>=1.0.2",
|
||||||
"unidic_lite>=1.0.7",
|
"unidic_lite>=1.0.7",
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ deps = {
|
|||||||
"tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.14",
|
"tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.14",
|
||||||
"torch": "torch>=1.7,!=1.12.0",
|
"torch": "torch>=1.7,!=1.12.0",
|
||||||
"torchaudio": "torchaudio",
|
"torchaudio": "torchaudio",
|
||||||
"pyctcdecode": "pyctcdecode>=0.3.0",
|
"pyctcdecode": "pyctcdecode>=0.4.0",
|
||||||
"tqdm": "tqdm>=4.27",
|
"tqdm": "tqdm>=4.27",
|
||||||
"unidic": "unidic>=1.0.2",
|
"unidic": "unidic>=1.0.2",
|
||||||
"unidic_lite": "unidic_lite>=1.0.7",
|
"unidic_lite": "unidic_lite>=1.0.7",
|
||||||
|
|||||||
@@ -442,9 +442,9 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
|
Please take a look at the Example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to make
|
||||||
understand how to make use of `output_word_offsets`.
|
use of `output_char_offsets`. [`~Wav2Vec2CTCTokenizer.batch_decode`] works the same way with batched
|
||||||
[`~model.wav2vec2.tokenization_wav2vec2.batch_decode`] works the same way with batched output.
|
output.
|
||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
@@ -454,9 +454,9 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
|
Please take a look at the Example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to make
|
||||||
understand how to make use of `output_word_offsets`.
|
use of `output_word_offsets`. [`~Wav2Vec2CTCTokenizer.batch_decode`] works the same way with batched
|
||||||
[`~model.wav2vec2.tokenization_wav2vec2.batch_decode`] works the same way with batched output.
|
output.
|
||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
@@ -515,8 +515,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
Please take a look at the example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
|
Please take a look at the example below to better understand how to make use of `output_char_offsets`.
|
||||||
understand how to make use of `output_word_offsets`.
|
|
||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
@@ -526,8 +525,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
Please take a look at the example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
|
Please take a look at the example below to better understand how to make use of `output_word_offsets`.
|
||||||
understand how to make use of `output_word_offsets`.
|
|
||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
|
|||||||
@@ -17,15 +17,18 @@ Speech processor class for Wav2Vec2
|
|||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager, nullcontext
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from multiprocessing import get_context
|
from multiprocessing import Pool, get_context, get_start_method
|
||||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union
|
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ...processing_utils import ProcessorMixin
|
from ...processing_utils import ProcessorMixin
|
||||||
from ...utils import ModelOutput, requires_backends
|
from ...utils import ModelOutput, logging, requires_backends
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -115,7 +118,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
|
|
||||||
This class method is simply calling Wav2Vec2FeatureExtractor's
|
This class method is simply calling Wav2Vec2FeatureExtractor's
|
||||||
[`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`], Wav2Vec2CTCTokenizer's
|
[`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`], Wav2Vec2CTCTokenizer's
|
||||||
[`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`], and
|
[`~tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`], and
|
||||||
[`pyctcdecode.BeamSearchDecoderCTC.load_from_hf_hub`].
|
[`pyctcdecode.BeamSearchDecoderCTC.load_from_hf_hub`].
|
||||||
|
|
||||||
Please refer to the docstrings of the methods above for more information.
|
Please refer to the docstrings of the methods above for more information.
|
||||||
@@ -280,6 +283,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
def batch_decode(
|
def batch_decode(
|
||||||
self,
|
self,
|
||||||
logits: np.ndarray,
|
logits: np.ndarray,
|
||||||
|
pool: Optional[Pool] = None,
|
||||||
num_processes: Optional[int] = None,
|
num_processes: Optional[int] = None,
|
||||||
beam_width: Optional[int] = None,
|
beam_width: Optional[int] = None,
|
||||||
beam_prune_logp: Optional[float] = None,
|
beam_prune_logp: Optional[float] = None,
|
||||||
@@ -297,16 +301,32 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
This function makes use of Python's multiprocessing.
|
This function makes use of Python's multiprocessing. Currently, multiprocessing is available only on Unix
|
||||||
|
systems (see this [issue](https://github.com/kensho-technologies/pyctcdecode/issues/65)).
|
||||||
|
|
||||||
|
If you are decoding multiple batches, consider creating a `Pool` and passing it to `batch_decode`. Otherwise,
|
||||||
|
`batch_decode` will be very slow since it will create a fresh `Pool` for each call. See usage example below.
|
||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logits (`np.ndarray`):
|
logits (`np.ndarray`):
|
||||||
The logits output vector of the model representing the log probabilities for each token.
|
The logits output vector of the model representing the log probabilities for each token.
|
||||||
|
pool (`multiprocessing.Pool`, *optional*):
|
||||||
|
An optional user-managed pool. If not set, one will be automatically created and closed. The pool
|
||||||
|
should be instantiated *after* `Wav2Vec2ProcessorWithLM`. Otherwise, the LM won't be available to the
|
||||||
|
pool's sub-processes.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
Currently, only pools created with a 'fork' context can be used. If a 'spawn' pool is passed, it will
|
||||||
|
be ignored and sequential decoding will be used instead.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
num_processes (`int`, *optional*):
|
num_processes (`int`, *optional*):
|
||||||
Number of processes on which the function should be parallelized over. Defaults to the number of
|
If `pool` is not set, number of processes on which the function should be parallelized over. Defaults
|
||||||
available CPUs.
|
to the number of available CPUs.
|
||||||
beam_width (`int`, *optional*):
|
beam_width (`int`, *optional*):
|
||||||
Maximum number of beams at each step in decoding. Defaults to pyctcdecode's DEFAULT_BEAM_WIDTH.
|
Maximum number of beams at each step in decoding. Defaults to pyctcdecode's DEFAULT_BEAM_WIDTH.
|
||||||
beam_prune_logp (`int`, *optional*):
|
beam_prune_logp (`int`, *optional*):
|
||||||
@@ -332,17 +352,19 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
Please take a look at the Example of [`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to
|
Please take a look at the Example of [`~Wav2Vec2ProcessorWithLM.decode`] to better understand how to
|
||||||
better understand how to make use of `output_word_offsets`.
|
make use of `output_word_offsets`. [`~Wav2Vec2ProcessorWithLM.batch_decode`] works the same way with
|
||||||
[`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.batch_decode`] works the same way with batched
|
batched output.
|
||||||
output.
|
|
||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
|
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`].
|
||||||
|
|
||||||
|
Example:
|
||||||
|
See [Decoding multiple audios](#decoding-multiple-audios).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pyctcdecode.constants import (
|
from pyctcdecode.constants import (
|
||||||
DEFAULT_BEAM_WIDTH,
|
DEFAULT_BEAM_WIDTH,
|
||||||
DEFAULT_HOTWORD_WEIGHT,
|
DEFAULT_HOTWORD_WEIGHT,
|
||||||
@@ -364,11 +386,34 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
# create multiprocessing pool and list numpy arrays
|
# create multiprocessing pool and list numpy arrays
|
||||||
# filter out logits padding
|
# filter out logits padding
|
||||||
logits_list = [array[(array != -100.0).all(axis=-1)] for array in logits]
|
logits_list = [array[(array != -100.0).all(axis=-1)] for array in logits]
|
||||||
pool = get_context("fork").Pool(num_processes)
|
|
||||||
|
# create a pool if necessary while also using it as a context manager to close itself
|
||||||
|
if pool is None:
|
||||||
|
# fork is safe to use only on Unix, see "Contexts and start methods" section on
|
||||||
|
# multiprocessing's docs (https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)
|
||||||
|
default_context = get_start_method()
|
||||||
|
|
||||||
|
if default_context == "fork":
|
||||||
|
cm = pool = get_context().Pool(num_processes)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Parallel batch decoding is not currently supported in this platform. "
|
||||||
|
"Falling back to sequential decoding."
|
||||||
|
)
|
||||||
|
cm = nullcontext()
|
||||||
|
else:
|
||||||
|
# pool is managed by the user, so we don't need to close it
|
||||||
|
cm = nullcontext()
|
||||||
|
|
||||||
|
if num_processes is not None:
|
||||||
|
logger.warning(
|
||||||
|
"Parameter `num_process` was passed, but it will be ignored since `pool` was also specified."
|
||||||
|
)
|
||||||
|
|
||||||
# pyctcdecode
|
# pyctcdecode
|
||||||
|
with cm:
|
||||||
decoded_beams = self.decoder.decode_beams_batch(
|
decoded_beams = self.decoder.decode_beams_batch(
|
||||||
pool,
|
pool=pool,
|
||||||
logits_list=logits_list,
|
logits_list=logits_list,
|
||||||
beam_width=beam_width,
|
beam_width=beam_width,
|
||||||
beam_prune_logp=beam_prune_logp,
|
beam_prune_logp=beam_prune_logp,
|
||||||
@@ -377,9 +422,6 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
hotword_weight=hotword_weight,
|
hotword_weight=hotword_weight,
|
||||||
)
|
)
|
||||||
|
|
||||||
# clone multi-processing pool
|
|
||||||
pool.close()
|
|
||||||
|
|
||||||
# extract text and scores
|
# extract text and scores
|
||||||
batch_texts, logit_scores, lm_scores, word_offsets = [], [], [], []
|
batch_texts, logit_scores, lm_scores, word_offsets = [], [], [], []
|
||||||
for d in decoded_beams:
|
for d in decoded_beams:
|
||||||
@@ -440,13 +482,12 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
Please take a look at the example of [`~models.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to
|
Please take a look at the example below to better understand how to make use of `output_word_offsets`.
|
||||||
better understand how to make use of `output_word_offsets`.
|
|
||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
|
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`].
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
|
|||||||
@@ -99,8 +99,8 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
This class method is simply calling [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] and
|
This class method is simply calling [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] and
|
||||||
[`~tokenization_utils_base.PreTrainedTokenizer.save_pretrained`]. Please refer to the docstrings of the methods
|
[`~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`]. Please refer to the docstrings of the
|
||||||
above for more information.
|
methods above for more information.
|
||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import math
|
import math
|
||||||
|
import multiprocessing
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -21,6 +22,7 @@ from datasets import load_dataset
|
|||||||
|
|
||||||
from transformers import Wav2Vec2Config, is_flax_available
|
from transformers import Wav2Vec2Config, is_flax_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
CaptureLogger,
|
||||||
is_flaky,
|
is_flaky,
|
||||||
is_librosa_available,
|
is_librosa_available,
|
||||||
is_pt_flax_cross_test,
|
is_pt_flax_cross_test,
|
||||||
@@ -53,6 +55,7 @@ if is_flax_available():
|
|||||||
|
|
||||||
if is_pyctcdecode_available():
|
if is_pyctcdecode_available():
|
||||||
from transformers import Wav2Vec2ProcessorWithLM
|
from transformers import Wav2Vec2ProcessorWithLM
|
||||||
|
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
|
||||||
|
|
||||||
|
|
||||||
if is_librosa_available():
|
if is_librosa_available():
|
||||||
@@ -554,3 +557,58 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
transcription = processor.batch_decode(np.array(logits)).text
|
transcription = processor.batch_decode(np.array(logits)).text
|
||||||
|
|
||||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||||
|
|
||||||
|
@require_pyctcdecode
|
||||||
|
@require_librosa
|
||||||
|
def test_wav2vec2_with_lm_pool(self):
|
||||||
|
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
||||||
|
sample = next(iter(ds))
|
||||||
|
|
||||||
|
resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
|
||||||
|
|
||||||
|
model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||||
|
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||||
|
|
||||||
|
input_values = processor(resampled_audio, return_tensors="np").input_values
|
||||||
|
|
||||||
|
logits = model(input_values).logits
|
||||||
|
|
||||||
|
# test user-managed pool
|
||||||
|
with multiprocessing.get_context("fork").Pool(2) as pool:
|
||||||
|
transcription = processor.batch_decode(logits.numpy(), pool).text
|
||||||
|
|
||||||
|
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||||
|
|
||||||
|
# user-managed pool + num_processes should trigger a warning
|
||||||
|
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
|
||||||
|
2
|
||||||
|
) as pool:
|
||||||
|
transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
|
||||||
|
|
||||||
|
self.assertIn("num_process", cl.out)
|
||||||
|
self.assertIn("it will be ignored", cl.out)
|
||||||
|
|
||||||
|
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||||
|
|
||||||
|
@require_pyctcdecode
|
||||||
|
@require_librosa
|
||||||
|
def test_wav2vec2_with_lm_invalid_pool(self):
|
||||||
|
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
||||||
|
sample = next(iter(ds))
|
||||||
|
|
||||||
|
resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
|
||||||
|
|
||||||
|
model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||||
|
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||||
|
|
||||||
|
input_values = processor(resampled_audio, return_tensors="np").input_values
|
||||||
|
|
||||||
|
logits = model(input_values).logits
|
||||||
|
|
||||||
|
# change default start method, which should trigger a warning if different than fork
|
||||||
|
multiprocessing.set_start_method("spawn")
|
||||||
|
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
|
||||||
|
transcription = processor.batch_decode(logits.numpy()).text
|
||||||
|
|
||||||
|
self.assertIn("Falling back to sequential decoding.", cl.out)
|
||||||
|
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import copy
|
|||||||
import glob
|
import glob
|
||||||
import inspect
|
import inspect
|
||||||
import math
|
import math
|
||||||
|
import multiprocessing
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -26,7 +27,7 @@ from datasets import load_dataset
|
|||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from transformers import Wav2Vec2Config, is_tf_available
|
from transformers import Wav2Vec2Config, is_tf_available
|
||||||
from transformers.testing_utils import is_flaky, require_librosa, require_pyctcdecode, require_tf, slow
|
from transformers.testing_utils import CaptureLogger, is_flaky, require_librosa, require_pyctcdecode, require_tf, slow
|
||||||
from transformers.utils import is_librosa_available, is_pyctcdecode_available
|
from transformers.utils import is_librosa_available, is_pyctcdecode_available
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@@ -42,6 +43,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
if is_pyctcdecode_available():
|
if is_pyctcdecode_available():
|
||||||
from transformers import Wav2Vec2ProcessorWithLM
|
from transformers import Wav2Vec2ProcessorWithLM
|
||||||
|
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
|
||||||
|
|
||||||
|
|
||||||
if is_librosa_available():
|
if is_librosa_available():
|
||||||
@@ -590,3 +592,56 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
transcription = processor.batch_decode(logits.numpy()).text
|
transcription = processor.batch_decode(logits.numpy()).text
|
||||||
|
|
||||||
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
||||||
|
|
||||||
|
@require_pyctcdecode
|
||||||
|
@require_librosa
|
||||||
|
def test_wav2vec2_with_lm_pool(self):
|
||||||
|
downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
|
||||||
|
file_path = glob.glob(downloaded_folder + "/*")[0]
|
||||||
|
sample = librosa.load(file_path, sr=16_000)[0]
|
||||||
|
|
||||||
|
model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||||
|
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||||
|
|
||||||
|
input_values = processor(sample, return_tensors="tf").input_values
|
||||||
|
|
||||||
|
logits = model(input_values).logits
|
||||||
|
|
||||||
|
# test user-managed pool
|
||||||
|
with multiprocessing.get_context("fork").Pool(2) as pool:
|
||||||
|
transcription = processor.batch_decode(logits.numpy(), pool).text
|
||||||
|
|
||||||
|
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
||||||
|
|
||||||
|
# user-managed pool + num_processes should trigger a warning
|
||||||
|
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
|
||||||
|
2
|
||||||
|
) as pool:
|
||||||
|
transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
|
||||||
|
|
||||||
|
self.assertIn("num_process", cl.out)
|
||||||
|
self.assertIn("it will be ignored", cl.out)
|
||||||
|
|
||||||
|
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
||||||
|
|
||||||
|
@require_pyctcdecode
|
||||||
|
@require_librosa
|
||||||
|
def test_wav2vec2_with_lm_invalid_pool(self):
|
||||||
|
downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
|
||||||
|
file_path = glob.glob(downloaded_folder + "/*")[0]
|
||||||
|
sample = librosa.load(file_path, sr=16_000)[0]
|
||||||
|
|
||||||
|
model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||||
|
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||||
|
|
||||||
|
input_values = processor(sample, return_tensors="tf").input_values
|
||||||
|
|
||||||
|
logits = model(input_values).logits
|
||||||
|
|
||||||
|
# change default start method, which should trigger a warning if different than fork
|
||||||
|
multiprocessing.set_start_method("spawn")
|
||||||
|
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
|
||||||
|
transcription = processor.batch_decode(logits.numpy()).text
|
||||||
|
|
||||||
|
self.assertIn("Falling back to sequential decoding.", cl.out)
|
||||||
|
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
""" Testing suite for the PyTorch Wav2Vec2 model. """
|
""" Testing suite for the PyTorch Wav2Vec2 model. """
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import tempfile
|
import tempfile
|
||||||
@@ -25,6 +26,7 @@ from datasets import load_dataset
|
|||||||
|
|
||||||
from transformers import Wav2Vec2Config, is_torch_available
|
from transformers import Wav2Vec2Config, is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
CaptureLogger,
|
||||||
is_pt_flax_cross_test,
|
is_pt_flax_cross_test,
|
||||||
is_pyctcdecode_available,
|
is_pyctcdecode_available,
|
||||||
is_torchaudio_available,
|
is_torchaudio_available,
|
||||||
@@ -74,6 +76,7 @@ if is_torchaudio_available():
|
|||||||
|
|
||||||
if is_pyctcdecode_available():
|
if is_pyctcdecode_available():
|
||||||
from transformers import Wav2Vec2ProcessorWithLM
|
from transformers import Wav2Vec2ProcessorWithLM
|
||||||
|
from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
|
||||||
|
|
||||||
|
|
||||||
if is_torch_fx_available():
|
if is_torch_fx_available():
|
||||||
@@ -1611,6 +1614,71 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||||
|
|
||||||
|
@require_pyctcdecode
|
||||||
|
@require_torchaudio
|
||||||
|
def test_wav2vec2_with_lm_pool(self):
|
||||||
|
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
||||||
|
sample = next(iter(ds))
|
||||||
|
|
||||||
|
resampled_audio = torchaudio.functional.resample(
|
||||||
|
torch.tensor(sample["audio"]["array"]), 48_000, 16_000
|
||||||
|
).numpy()
|
||||||
|
|
||||||
|
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
|
||||||
|
torch_device
|
||||||
|
)
|
||||||
|
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||||
|
|
||||||
|
input_values = processor(resampled_audio, return_tensors="pt").input_values
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = model(input_values.to(torch_device)).logits
|
||||||
|
|
||||||
|
# test user-managed pool
|
||||||
|
with multiprocessing.get_context("fork").Pool(2) as pool:
|
||||||
|
transcription = processor.batch_decode(logits.numpy(), pool).text
|
||||||
|
|
||||||
|
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||||
|
|
||||||
|
# user-managed pool + num_processes should trigger a warning
|
||||||
|
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
|
||||||
|
2
|
||||||
|
) as pool:
|
||||||
|
transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
|
||||||
|
|
||||||
|
self.assertIn("num_process", cl.out)
|
||||||
|
self.assertIn("it will be ignored", cl.out)
|
||||||
|
|
||||||
|
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||||
|
|
||||||
|
@require_pyctcdecode
|
||||||
|
@require_torchaudio
|
||||||
|
def test_wav2vec2_with_lm_invalid_pool(self):
|
||||||
|
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
||||||
|
sample = next(iter(ds))
|
||||||
|
|
||||||
|
resampled_audio = torchaudio.functional.resample(
|
||||||
|
torch.tensor(sample["audio"]["array"]), 48_000, 16_000
|
||||||
|
).numpy()
|
||||||
|
|
||||||
|
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
|
||||||
|
torch_device
|
||||||
|
)
|
||||||
|
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||||
|
|
||||||
|
input_values = processor(resampled_audio, return_tensors="pt").input_values
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = model(input_values.to(torch_device)).logits
|
||||||
|
|
||||||
|
# change default start method, which should trigger a warning if different than fork
|
||||||
|
multiprocessing.set_start_method("spawn")
|
||||||
|
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
|
||||||
|
transcription = processor.batch_decode(logits.numpy()).text
|
||||||
|
|
||||||
|
self.assertIn("Falling back to sequential decoding.", cl.out)
|
||||||
|
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
||||||
|
|
||||||
def test_inference_diarization(self):
|
def test_inference_diarization(self):
|
||||||
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device)
|
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device)
|
||||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/wav2vec2-base-superb-sd")
|
processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/wav2vec2-base-superb-sd")
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import numpy as np
|
|||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
|
from parameterized import parameterized
|
||||||
from transformers import AutoProcessor
|
from transformers import AutoProcessor
|
||||||
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
|
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
|
||||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||||
@@ -194,7 +195,8 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
|||||||
self.assertEqual(decoded_decoder[-2], decoded_processor.logit_score)
|
self.assertEqual(decoded_decoder[-2], decoded_processor.logit_score)
|
||||||
self.assertEqual(decoded_decoder[-1], decoded_processor.lm_score)
|
self.assertEqual(decoded_decoder[-1], decoded_processor.lm_score)
|
||||||
|
|
||||||
def test_decoder_batch(self):
|
@parameterized.expand([[None], ["fork"], ["spawn"]])
|
||||||
|
def test_decoder_batch(self, pool_context):
|
||||||
feature_extractor = self.get_feature_extractor()
|
feature_extractor = self.get_feature_extractor()
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
decoder = self.get_decoder()
|
decoder = self.get_decoder()
|
||||||
@@ -203,17 +205,25 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
|||||||
|
|
||||||
logits = self._get_dummy_logits()
|
logits = self._get_dummy_logits()
|
||||||
|
|
||||||
|
# note: pool should be instantiated *after* Wav2Vec2ProcessorWithLM.
|
||||||
|
# otherwise, the LM won't be available to the pool's sub-processes.
|
||||||
|
# manual logic used to allow parameterized test for both pool=None and pool=Pool(...)
|
||||||
|
if pool_context is None:
|
||||||
decoded_processor = processor.batch_decode(logits)
|
decoded_processor = processor.batch_decode(logits)
|
||||||
|
else:
|
||||||
|
with get_context(pool_context).Pool() as pool:
|
||||||
|
decoded_processor = processor.batch_decode(logits, pool)
|
||||||
|
|
||||||
logits_list = [array for array in logits]
|
logits_list = [array for array in logits]
|
||||||
pool = get_context("fork").Pool()
|
|
||||||
decoded_beams = decoder.decode_beams_batch(pool, logits_list)
|
with get_context("fork").Pool() as p:
|
||||||
|
decoded_beams = decoder.decode_beams_batch(p, logits_list)
|
||||||
|
|
||||||
texts_decoder, logit_scores_decoder, lm_scores_decoder = [], [], []
|
texts_decoder, logit_scores_decoder, lm_scores_decoder = [], [], []
|
||||||
for beams in decoded_beams:
|
for beams in decoded_beams:
|
||||||
texts_decoder.append(beams[0][0])
|
texts_decoder.append(beams[0][0])
|
||||||
logit_scores_decoder.append(beams[0][-2])
|
logit_scores_decoder.append(beams[0][-2])
|
||||||
lm_scores_decoder.append(beams[0][-1])
|
lm_scores_decoder.append(beams[0][-1])
|
||||||
pool.close()
|
|
||||||
|
|
||||||
self.assertListEqual(texts_decoder, decoded_processor.text)
|
self.assertListEqual(texts_decoder, decoded_processor.text)
|
||||||
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor.text)
|
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor.text)
|
||||||
@@ -242,7 +252,8 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
|||||||
decoded_processor = decoded_processor_out.text
|
decoded_processor = decoded_processor_out.text
|
||||||
|
|
||||||
logits_list = [array for array in logits]
|
logits_list = [array for array in logits]
|
||||||
pool = get_context("fork").Pool()
|
|
||||||
|
with get_context("fork").Pool() as pool:
|
||||||
decoded_decoder_out = decoder.decode_beams_batch(
|
decoded_decoder_out = decoder.decode_beams_batch(
|
||||||
pool,
|
pool,
|
||||||
logits_list,
|
logits_list,
|
||||||
@@ -250,7 +261,6 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
|||||||
beam_prune_logp=beam_prune_logp,
|
beam_prune_logp=beam_prune_logp,
|
||||||
token_min_logp=token_min_logp,
|
token_min_logp=token_min_logp,
|
||||||
)
|
)
|
||||||
pool.close()
|
|
||||||
|
|
||||||
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
|
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
|
||||||
|
|
||||||
@@ -287,12 +297,12 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
|||||||
unk_score_offset=unk_score_offset,
|
unk_score_offset=unk_score_offset,
|
||||||
lm_score_boundary=lm_score_boundary,
|
lm_score_boundary=lm_score_boundary,
|
||||||
)
|
)
|
||||||
pool = get_context("fork").Pool()
|
|
||||||
|
with get_context("fork").Pool() as pool:
|
||||||
decoded_decoder_out = decoder.decode_beams_batch(
|
decoded_decoder_out = decoder.decode_beams_batch(
|
||||||
pool,
|
pool,
|
||||||
logits_list,
|
logits_list,
|
||||||
)
|
)
|
||||||
pool.close()
|
|
||||||
|
|
||||||
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
|
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user