diff --git a/docs/source/en/model_doc/wav2vec2.mdx b/docs/source/en/model_doc/wav2vec2.mdx index eaca36be46..79870f5623 100644 --- a/docs/source/en/model_doc/wav2vec2.mdx +++ b/docs/source/en/model_doc/wav2vec2.mdx @@ -73,6 +73,61 @@ This model was contributed by [patrickvonplaten](https://huggingface.co/patrickv - batch_decode - decode +### Decoding multiple audios + +If you are planning to decode multiple batches of audios, you should consider using [`~Wav2Vec2ProcessorWithLM.batch_decode`] and passing an instantiated `multiprocessing.Pool`. +Otherwise, [`~Wav2Vec2ProcessorWithLM.batch_decode`] performance will be slower than calling [`~Wav2Vec2ProcessorWithLM.decode`] for each audio individually, as it internally instantiates a new `Pool` for every call. See the example below: + +```python +>>> # Let's see how to use a user-managed pool for batch decoding multiple audios +>>> from multiprocessing import get_context +>>> from transformers import AutoTokenizer, AutoProcessor, AutoModelForCTC +>>> from datasets import load_dataset +>>> import datasets +>>> import torch + +>>> # import model, feature extractor, tokenizer +>>> model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm").to("cuda") +>>> processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm") + +>>> # load example dataset +>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") +>>> dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000)) + + +>>> def map_to_array(batch): +... batch["speech"] = batch["audio"]["array"] +... return batch + + +>>> # prepare speech data for batch inference +>>> dataset = dataset.map(map_to_array, remove_columns=["audio"]) + + +>>> def map_to_pred(batch, pool): +... inputs = processor(batch["speech"], sampling_rate=16_000, padding=True, return_tensors="pt") +... inputs = {k: v.to("cuda") for k, v in inputs.items()} + +... with torch.no_grad(): +... logits = model(**inputs).logits + +... transcription = processor.batch_decode(logits.cpu().numpy(), pool).text +... batch["transcription"] = transcription +... return batch + + +>>> # note: pool should be instantiated *after* `Wav2Vec2ProcessorWithLM`. +>>> # otherwise, the LM won't be available to the pool's sub-processes +>>> # select number of processes and batch_size based on number of CPU cores available and on dataset size +>>> with get_context("fork").Pool(processes=2) as pool: +... result = dataset.map( +... map_to_pred, batched=True, batch_size=2, fn_kwargs={"pool": pool}, remove_columns=["speech"] +... ) + +>>> result["transcription"][:2] +['MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL', "NOR IS MISTER COULTER'S MANNER LESS INTERESTING THAN HIS MATTER"] +``` + ## Wav2Vec2 specific outputs [[autodoc]] models.wav2vec2_with_lm.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput diff --git a/setup.py b/setup.py index 13f8b42d97..cb5c50b058 100644 --- a/setup.py +++ b/setup.py @@ -164,7 +164,7 @@ _deps = [ "tokenizers>=0.11.1,!=0.11.3,<0.14", "torch>=1.7,!=1.12.0", "torchaudio", - "pyctcdecode>=0.3.0", + "pyctcdecode>=0.4.0", "tqdm>=4.27", "unidic>=1.0.2", "unidic_lite>=1.0.7", diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 48a803fcdf..e55219b796 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -70,7 +70,7 @@ deps = { "tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.14", "torch": "torch>=1.7,!=1.12.0", "torchaudio": "torchaudio", - "pyctcdecode": "pyctcdecode>=0.3.0", + "pyctcdecode": "pyctcdecode>=0.4.0", "tqdm": "tqdm>=4.27", "unidic": "unidic>=1.0.2", "unidic_lite": "unidic_lite>=1.0.7", diff --git a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py index 1e77959400..8d8406817d 100644 --- a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py +++ b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py @@ -442,9 +442,9 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): - Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better - understand how to make use of `output_word_offsets`. - [`~model.wav2vec2.tokenization_wav2vec2.batch_decode`] works the same way with batched output. + Please take a look at the Example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to make + use of `output_char_offsets`. [`~Wav2Vec2CTCTokenizer.batch_decode`] works the same way with batched + output. @@ -454,9 +454,9 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): - Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better - understand how to make use of `output_word_offsets`. - [`~model.wav2vec2.tokenization_wav2vec2.batch_decode`] works the same way with batched output. + Please take a look at the Example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to make + use of `output_word_offsets`. [`~Wav2Vec2CTCTokenizer.batch_decode`] works the same way with batched + output. @@ -515,8 +515,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): - Please take a look at the example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better - understand how to make use of `output_word_offsets`. + Please take a look at the example below to better understand how to make use of `output_char_offsets`. @@ -526,8 +525,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): - Please take a look at the example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better - understand how to make use of `output_word_offsets`. + Please take a look at the example below to better understand how to make use of `output_word_offsets`. diff --git a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py index f09b5eb922..5c93b68507 100644 --- a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +++ b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py @@ -17,15 +17,18 @@ Speech processor class for Wav2Vec2 """ import os import warnings -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from dataclasses import dataclass -from multiprocessing import get_context +from multiprocessing import Pool, get_context, get_start_method from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union import numpy as np from ...processing_utils import ProcessorMixin -from ...utils import ModelOutput, requires_backends +from ...utils import ModelOutput, logging, requires_backends + + +logger = logging.get_logger(__name__) if TYPE_CHECKING: @@ -115,7 +118,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): This class method is simply calling Wav2Vec2FeatureExtractor's [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`], Wav2Vec2CTCTokenizer's - [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`], and + [`~tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`], and [`pyctcdecode.BeamSearchDecoderCTC.load_from_hf_hub`]. Please refer to the docstrings of the methods above for more information. @@ -280,6 +283,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): def batch_decode( self, logits: np.ndarray, + pool: Optional[Pool] = None, num_processes: Optional[int] = None, beam_width: Optional[int] = None, beam_prune_logp: Optional[float] = None, @@ -297,16 +301,32 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): - This function makes use of Python's multiprocessing. + This function makes use of Python's multiprocessing. Currently, multiprocessing is available only on Unix + systems (see this [issue](https://github.com/kensho-technologies/pyctcdecode/issues/65)). + + If you are decoding multiple batches, consider creating a `Pool` and passing it to `batch_decode`. Otherwise, + `batch_decode` will be very slow since it will create a fresh `Pool` for each call. See usage example below. Args: logits (`np.ndarray`): The logits output vector of the model representing the log probabilities for each token. + pool (`multiprocessing.Pool`, *optional*): + An optional user-managed pool. If not set, one will be automatically created and closed. The pool + should be instantiated *after* `Wav2Vec2ProcessorWithLM`. Otherwise, the LM won't be available to the + pool's sub-processes. + + + + Currently, only pools created with a 'fork' context can be used. If a 'spawn' pool is passed, it will + be ignored and sequential decoding will be used instead. + + + num_processes (`int`, *optional*): - Number of processes on which the function should be parallelized over. Defaults to the number of - available CPUs. + If `pool` is not set, number of processes on which the function should be parallelized over. Defaults + to the number of available CPUs. beam_width (`int`, *optional*): Maximum number of beams at each step in decoding. Defaults to pyctcdecode's DEFAULT_BEAM_WIDTH. beam_prune_logp (`int`, *optional*): @@ -332,17 +352,19 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): - Please take a look at the Example of [`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to - better understand how to make use of `output_word_offsets`. - [`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.batch_decode`] works the same way with batched - output. + Please take a look at the Example of [`~Wav2Vec2ProcessorWithLM.decode`] to better understand how to + make use of `output_word_offsets`. [`~Wav2Vec2ProcessorWithLM.batch_decode`] works the same way with + batched output. Returns: - [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`. + [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`]. + Example: + See [Decoding multiple audios](#decoding-multiple-audios). """ + from pyctcdecode.constants import ( DEFAULT_BEAM_WIDTH, DEFAULT_HOTWORD_WEIGHT, @@ -364,21 +386,41 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): # create multiprocessing pool and list numpy arrays # filter out logits padding logits_list = [array[(array != -100.0).all(axis=-1)] for array in logits] - pool = get_context("fork").Pool(num_processes) + + # create a pool if necessary while also using it as a context manager to close itself + if pool is None: + # fork is safe to use only on Unix, see "Contexts and start methods" section on + # multiprocessing's docs (https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods) + default_context = get_start_method() + + if default_context == "fork": + cm = pool = get_context().Pool(num_processes) + else: + logger.warning( + "Parallel batch decoding is not currently supported in this platform. " + "Falling back to sequential decoding." + ) + cm = nullcontext() + else: + # pool is managed by the user, so we don't need to close it + cm = nullcontext() + + if num_processes is not None: + logger.warning( + "Parameter `num_process` was passed, but it will be ignored since `pool` was also specified." + ) # pyctcdecode - decoded_beams = self.decoder.decode_beams_batch( - pool, - logits_list=logits_list, - beam_width=beam_width, - beam_prune_logp=beam_prune_logp, - token_min_logp=token_min_logp, - hotwords=hotwords, - hotword_weight=hotword_weight, - ) - - # clone multi-processing pool - pool.close() + with cm: + decoded_beams = self.decoder.decode_beams_batch( + pool=pool, + logits_list=logits_list, + beam_width=beam_width, + beam_prune_logp=beam_prune_logp, + token_min_logp=token_min_logp, + hotwords=hotwords, + hotword_weight=hotword_weight, + ) # extract text and scores batch_texts, logit_scores, lm_scores, word_offsets = [], [], [], [] @@ -440,13 +482,12 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): - Please take a look at the example of [`~models.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to - better understand how to make use of `output_word_offsets`. + Please take a look at the example below to better understand how to make use of `output_word_offsets`. Returns: - [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`. + [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`]. Example: diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 3915c3f8a5..569e9975b1 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -99,8 +99,8 @@ class ProcessorMixin(PushToHubMixin): This class method is simply calling [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] and - [`~tokenization_utils_base.PreTrainedTokenizer.save_pretrained`]. Please refer to the docstrings of the methods - above for more information. + [`~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`]. Please refer to the docstrings of the + methods above for more information. diff --git a/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py b/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py index aa6781a42e..b5524a4f69 100644 --- a/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py @@ -14,6 +14,7 @@ import inspect import math +import multiprocessing import unittest import numpy as np @@ -21,6 +22,7 @@ from datasets import load_dataset from transformers import Wav2Vec2Config, is_flax_available from transformers.testing_utils import ( + CaptureLogger, is_flaky, is_librosa_available, is_pt_flax_cross_test, @@ -53,6 +55,7 @@ if is_flax_available(): if is_pyctcdecode_available(): from transformers import Wav2Vec2ProcessorWithLM + from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm if is_librosa_available(): @@ -554,3 +557,58 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase): transcription = processor.batch_decode(np.array(logits)).text self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") + + @require_pyctcdecode + @require_librosa + def test_wav2vec2_with_lm_pool(self): + ds = load_dataset("common_voice", "es", split="test", streaming=True) + sample = next(iter(ds)) + + resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000) + + model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + + input_values = processor(resampled_audio, return_tensors="np").input_values + + logits = model(input_values).logits + + # test user-managed pool + with multiprocessing.get_context("fork").Pool(2) as pool: + transcription = processor.batch_decode(logits.numpy(), pool).text + + self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") + + # user-managed pool + num_processes should trigger a warning + with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool( + 2 + ) as pool: + transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text + + self.assertIn("num_process", cl.out) + self.assertIn("it will be ignored", cl.out) + + self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") + + @require_pyctcdecode + @require_librosa + def test_wav2vec2_with_lm_invalid_pool(self): + ds = load_dataset("common_voice", "es", split="test", streaming=True) + sample = next(iter(ds)) + + resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000) + + model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + + input_values = processor(resampled_audio, return_tensors="np").input_values + + logits = model(input_values).logits + + # change default start method, which should trigger a warning if different than fork + multiprocessing.set_start_method("spawn") + with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl: + transcription = processor.batch_decode(logits.numpy()).text + + self.assertIn("Falling back to sequential decoding.", cl.out) + self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") diff --git a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py index 6ea1919a33..5c2a25f413 100644 --- a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py @@ -18,6 +18,7 @@ import copy import glob import inspect import math +import multiprocessing import unittest import numpy as np @@ -26,7 +27,7 @@ from datasets import load_dataset from huggingface_hub import snapshot_download from transformers import Wav2Vec2Config, is_tf_available -from transformers.testing_utils import is_flaky, require_librosa, require_pyctcdecode, require_tf, slow +from transformers.testing_utils import CaptureLogger, is_flaky, require_librosa, require_pyctcdecode, require_tf, slow from transformers.utils import is_librosa_available, is_pyctcdecode_available from ...test_configuration_common import ConfigTester @@ -42,6 +43,7 @@ if is_tf_available(): if is_pyctcdecode_available(): from transformers import Wav2Vec2ProcessorWithLM + from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm if is_librosa_available(): @@ -590,3 +592,56 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase): transcription = processor.batch_decode(logits.numpy()).text self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes") + + @require_pyctcdecode + @require_librosa + def test_wav2vec2_with_lm_pool(self): + downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample") + file_path = glob.glob(downloaded_folder + "/*")[0] + sample = librosa.load(file_path, sr=16_000)[0] + + model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + + input_values = processor(sample, return_tensors="tf").input_values + + logits = model(input_values).logits + + # test user-managed pool + with multiprocessing.get_context("fork").Pool(2) as pool: + transcription = processor.batch_decode(logits.numpy(), pool).text + + self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes") + + # user-managed pool + num_processes should trigger a warning + with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool( + 2 + ) as pool: + transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text + + self.assertIn("num_process", cl.out) + self.assertIn("it will be ignored", cl.out) + + self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes") + + @require_pyctcdecode + @require_librosa + def test_wav2vec2_with_lm_invalid_pool(self): + downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample") + file_path = glob.glob(downloaded_folder + "/*")[0] + sample = librosa.load(file_path, sr=16_000)[0] + + model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + + input_values = processor(sample, return_tensors="tf").input_values + + logits = model(input_values).logits + + # change default start method, which should trigger a warning if different than fork + multiprocessing.set_start_method("spawn") + with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl: + transcription = processor.batch_decode(logits.numpy()).text + + self.assertIn("Falling back to sequential decoding.", cl.out) + self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes") diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py index 040731472f..4d4fa2981d 100644 --- a/tests/models/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py @@ -15,6 +15,7 @@ """ Testing suite for the PyTorch Wav2Vec2 model. """ import math +import multiprocessing import os import pickle import tempfile @@ -25,6 +26,7 @@ from datasets import load_dataset from transformers import Wav2Vec2Config, is_torch_available from transformers.testing_utils import ( + CaptureLogger, is_pt_flax_cross_test, is_pyctcdecode_available, is_torchaudio_available, @@ -74,6 +76,7 @@ if is_torchaudio_available(): if is_pyctcdecode_available(): from transformers import Wav2Vec2ProcessorWithLM + from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm if is_torch_fx_available(): @@ -1611,6 +1614,71 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") + @require_pyctcdecode + @require_torchaudio + def test_wav2vec2_with_lm_pool(self): + ds = load_dataset("common_voice", "es", split="test", streaming=True) + sample = next(iter(ds)) + + resampled_audio = torchaudio.functional.resample( + torch.tensor(sample["audio"]["array"]), 48_000, 16_000 + ).numpy() + + model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to( + torch_device + ) + processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + + input_values = processor(resampled_audio, return_tensors="pt").input_values + + with torch.no_grad(): + logits = model(input_values.to(torch_device)).logits + + # test user-managed pool + with multiprocessing.get_context("fork").Pool(2) as pool: + transcription = processor.batch_decode(logits.numpy(), pool).text + + self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") + + # user-managed pool + num_processes should trigger a warning + with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool( + 2 + ) as pool: + transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text + + self.assertIn("num_process", cl.out) + self.assertIn("it will be ignored", cl.out) + + self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") + + @require_pyctcdecode + @require_torchaudio + def test_wav2vec2_with_lm_invalid_pool(self): + ds = load_dataset("common_voice", "es", split="test", streaming=True) + sample = next(iter(ds)) + + resampled_audio = torchaudio.functional.resample( + torch.tensor(sample["audio"]["array"]), 48_000, 16_000 + ).numpy() + + model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to( + torch_device + ) + processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + + input_values = processor(resampled_audio, return_tensors="pt").input_values + + with torch.no_grad(): + logits = model(input_values.to(torch_device)).logits + + # change default start method, which should trigger a warning if different than fork + multiprocessing.set_start_method("spawn") + with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl: + transcription = processor.batch_decode(logits.numpy()).text + + self.assertIn("Falling back to sequential decoding.", cl.out) + self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes") + def test_inference_diarization(self): model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device) processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/wav2vec2-base-superb-sd") diff --git a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py index 6bf52d3e1b..11a45b6e1a 100644 --- a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py +++ b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py @@ -25,6 +25,7 @@ import numpy as np from datasets import load_dataset from packaging import version +from parameterized import parameterized from transformers import AutoProcessor from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES @@ -194,7 +195,8 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): self.assertEqual(decoded_decoder[-2], decoded_processor.logit_score) self.assertEqual(decoded_decoder[-1], decoded_processor.lm_score) - def test_decoder_batch(self): + @parameterized.expand([[None], ["fork"], ["spawn"]]) + def test_decoder_batch(self, pool_context): feature_extractor = self.get_feature_extractor() tokenizer = self.get_tokenizer() decoder = self.get_decoder() @@ -203,17 +205,25 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): logits = self._get_dummy_logits() - decoded_processor = processor.batch_decode(logits) + # note: pool should be instantiated *after* Wav2Vec2ProcessorWithLM. + # otherwise, the LM won't be available to the pool's sub-processes. + # manual logic used to allow parameterized test for both pool=None and pool=Pool(...) + if pool_context is None: + decoded_processor = processor.batch_decode(logits) + else: + with get_context(pool_context).Pool() as pool: + decoded_processor = processor.batch_decode(logits, pool) logits_list = [array for array in logits] - pool = get_context("fork").Pool() - decoded_beams = decoder.decode_beams_batch(pool, logits_list) + + with get_context("fork").Pool() as p: + decoded_beams = decoder.decode_beams_batch(p, logits_list) + texts_decoder, logit_scores_decoder, lm_scores_decoder = [], [], [] for beams in decoded_beams: texts_decoder.append(beams[0][0]) logit_scores_decoder.append(beams[0][-2]) lm_scores_decoder.append(beams[0][-1]) - pool.close() self.assertListEqual(texts_decoder, decoded_processor.text) self.assertListEqual([" ", " "], decoded_processor.text) @@ -242,15 +252,15 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): decoded_processor = decoded_processor_out.text logits_list = [array for array in logits] - pool = get_context("fork").Pool() - decoded_decoder_out = decoder.decode_beams_batch( - pool, - logits_list, - beam_width=beam_width, - beam_prune_logp=beam_prune_logp, - token_min_logp=token_min_logp, - ) - pool.close() + + with get_context("fork").Pool() as pool: + decoded_decoder_out = decoder.decode_beams_batch( + pool, + logits_list, + beam_width=beam_width, + beam_prune_logp=beam_prune_logp, + token_min_logp=token_min_logp, + ) decoded_decoder = [d[0][0] for d in decoded_decoder_out] @@ -287,12 +297,12 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): unk_score_offset=unk_score_offset, lm_score_boundary=lm_score_boundary, ) - pool = get_context("fork").Pool() - decoded_decoder_out = decoder.decode_beams_batch( - pool, - logits_list, - ) - pool.close() + + with get_context("fork").Pool() as pool: + decoded_decoder_out = decoder.decode_beams_batch( + pool, + logits_list, + ) decoded_decoder = [d[0][0] for d in decoded_decoder_out]