Allow user-managed Pool in Wav2Vec2ProcessorWithLM.batch_decode (#18351)

* [Wav2Vec2] Allow user-managed Pool in Wav2Vec2ProcessorWithLM.batch_decode

* [Wav2Vec2] Add user-managed LM's pool tests and usage examples

* Improve styling

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* [Wav2Vec2] Fix hyperlink references

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Antonio Carlos Falcão Petri
2022-10-18 09:48:03 -03:00
committed by GitHub
parent bf0e094142
commit af150e4a1c
10 changed files with 348 additions and 63 deletions

View File

@@ -25,6 +25,7 @@ import numpy as np
from datasets import load_dataset
from packaging import version
from parameterized import parameterized
from transformers import AutoProcessor
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
@@ -194,7 +195,8 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
self.assertEqual(decoded_decoder[-2], decoded_processor.logit_score)
self.assertEqual(decoded_decoder[-1], decoded_processor.lm_score)
def test_decoder_batch(self):
@parameterized.expand([[None], ["fork"], ["spawn"]])
def test_decoder_batch(self, pool_context):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
decoder = self.get_decoder()
@@ -203,17 +205,25 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
logits = self._get_dummy_logits()
decoded_processor = processor.batch_decode(logits)
# note: pool should be instantiated *after* Wav2Vec2ProcessorWithLM.
# otherwise, the LM won't be available to the pool's sub-processes.
# manual logic used to allow parameterized test for both pool=None and pool=Pool(...)
if pool_context is None:
decoded_processor = processor.batch_decode(logits)
else:
with get_context(pool_context).Pool() as pool:
decoded_processor = processor.batch_decode(logits, pool)
logits_list = [array for array in logits]
pool = get_context("fork").Pool()
decoded_beams = decoder.decode_beams_batch(pool, logits_list)
with get_context("fork").Pool() as p:
decoded_beams = decoder.decode_beams_batch(p, logits_list)
texts_decoder, logit_scores_decoder, lm_scores_decoder = [], [], []
for beams in decoded_beams:
texts_decoder.append(beams[0][0])
logit_scores_decoder.append(beams[0][-2])
lm_scores_decoder.append(beams[0][-1])
pool.close()
self.assertListEqual(texts_decoder, decoded_processor.text)
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor.text)
@@ -242,15 +252,15 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
decoded_processor = decoded_processor_out.text
logits_list = [array for array in logits]
pool = get_context("fork").Pool()
decoded_decoder_out = decoder.decode_beams_batch(
pool,
logits_list,
beam_width=beam_width,
beam_prune_logp=beam_prune_logp,
token_min_logp=token_min_logp,
)
pool.close()
with get_context("fork").Pool() as pool:
decoded_decoder_out = decoder.decode_beams_batch(
pool,
logits_list,
beam_width=beam_width,
beam_prune_logp=beam_prune_logp,
token_min_logp=token_min_logp,
)
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
@@ -287,12 +297,12 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
unk_score_offset=unk_score_offset,
lm_score_boundary=lm_score_boundary,
)
pool = get_context("fork").Pool()
decoded_decoder_out = decoder.decode_beams_batch(
pool,
logits_list,
)
pool.close()
with get_context("fork").Pool() as pool:
decoded_decoder_out = decoder.decode_beams_batch(
pool,
logits_list,
)
decoded_decoder = [d[0][0] for d in decoded_decoder_out]