From d718c0c3a887bcab6acc151b3654bf9f46e61d62 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 2 Feb 2022 12:59:40 +0100 Subject: [PATCH] [Wav2Vec2ProcessorWithLM] add alpha & beta to batch decode & decode (#15465) --- .../processing_wav2vec2_with_lm.py | 34 +++++++++++ tests/test_processor_wav2vec2_with_lm.py | 57 ++++++++++++++++++- 2 files changed, 88 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py index 0c8ac8e098..148e42ec66 100644 --- a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +++ b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py @@ -253,6 +253,10 @@ class Wav2Vec2ProcessorWithLM: token_min_logp: Optional[float] = None, hotwords: Optional[Iterable[str]] = None, hotword_weight: Optional[float] = None, + alpha: Optional[float] = None, + beta: Optional[float] = None, + unk_score_offset: Optional[float] = None, + lm_score_boundary: Optional[bool] = None, ): """ Batch decode output logits to audio transcription with language model support. @@ -280,6 +284,14 @@ class Wav2Vec2ProcessorWithLM: List of words with extra importance, can be OOV for LM hotword_weight (`int`, *optional*): Weight factor for hotword importance Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT. + alpha (`float`, *optional*): + Weight for language model during shallow fusion + beta (`float`, *optional*): + Weight for length score adjustment of during scoring + unk_score_offset (`float`, *optional*): + Amount of log score offset for unknown tokens + lm_score_boundary (`bool`, *optional*): + Whether to have kenlm respect boundaries when scoring Returns: [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`. @@ -298,6 +310,11 @@ class Wav2Vec2ProcessorWithLM: token_min_logp = token_min_logp if token_min_logp is not None else DEFAULT_MIN_TOKEN_LOGP hotword_weight = hotword_weight if hotword_weight is not None else DEFAULT_HOTWORD_WEIGHT + # reset params at every forward call. It's just a `set` method in pyctcdecode + self.decoder.reset_params( + alpha=alpha, beta=beta, unk_score_offset=unk_score_offset, lm_score_boundary=lm_score_boundary + ) + # create multiprocessing pool and list numpy arrays logits_list = [array for array in logits] pool = get_context("fork").Pool(num_processes) @@ -330,6 +347,10 @@ class Wav2Vec2ProcessorWithLM: token_min_logp: Optional[float] = None, hotwords: Optional[Iterable[str]] = None, hotword_weight: Optional[float] = None, + alpha: Optional[float] = None, + beta: Optional[float] = None, + unk_score_offset: Optional[float] = None, + lm_score_boundary: Optional[bool] = None, ): """ Decode output logits to audio transcription with language model support. @@ -349,6 +370,14 @@ class Wav2Vec2ProcessorWithLM: List of words with extra importance which can be missing from the LM's vocabulary, e.g. ["huggingface"] hotword_weight (`int`, *optional*): Weight multiplier that boosts hotword scores. Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT. + alpha (`float`, *optional*): + Weight for language model during shallow fusion + beta (`float`, *optional*): + Weight for length score adjustment of during scoring + unk_score_offset (`float`, *optional*): + Amount of log score offset for unknown tokens + lm_score_boundary (`bool`, *optional*): + Whether to have kenlm respect boundaries when scoring Returns: [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`. @@ -367,6 +396,11 @@ class Wav2Vec2ProcessorWithLM: token_min_logp = token_min_logp if token_min_logp is not None else DEFAULT_MIN_TOKEN_LOGP hotword_weight = hotword_weight if hotword_weight is not None else DEFAULT_HOTWORD_WEIGHT + # reset params at every forward call. It's just a `set` method in pyctcdecode + self.decoder.reset_params( + alpha=alpha, beta=beta, unk_score_offset=unk_score_offset, lm_score_boundary=lm_score_boundary + ) + # pyctcdecode decoded_beams = self.decoder.decode_beams( logits, diff --git a/tests/test_processor_wav2vec2_with_lm.py b/tests/test_processor_wav2vec2_with_lm.py index 37c6ff01d9..f918a0894a 100644 --- a/tests/test_processor_wav2vec2_with_lm.py +++ b/tests/test_processor_wav2vec2_with_lm.py @@ -17,7 +17,7 @@ import os import shutil import tempfile import unittest -from multiprocessing import Pool +from multiprocessing import get_context from pathlib import Path import numpy as np @@ -196,7 +196,9 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): decoded_processor = processor.batch_decode(logits).text logits_list = [array for array in logits] - decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(Pool(), logits_list)] + pool = get_context("fork").Pool() + decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(pool, logits_list)] + pool.close() self.assertListEqual(decoded_decoder, decoded_processor) self.assertListEqual([" ", " "], decoded_processor) @@ -223,19 +225,68 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): decoded_processor = decoded_processor_out.text logits_list = [array for array in logits] + pool = get_context("fork").Pool() decoded_decoder_out = decoder.decode_beams_batch( - Pool(), + pool, logits_list, beam_width=beam_width, beam_prune_logp=beam_prune_logp, token_min_logp=token_min_logp, ) + pool.close() decoded_decoder = [d[0][0] for d in decoded_decoder_out] self.assertListEqual(decoded_decoder, decoded_processor) self.assertListEqual([" ", " "], decoded_processor) + def test_decoder_with_params_of_lm(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + decoder = self.get_decoder() + + processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder) + + logits = self._get_dummy_logits() + + alpha = 2.0 + beta = 5.0 + unk_score_offset = -20.0 + lm_score_boundary = True + + decoded_processor_out = processor.batch_decode( + logits, + alpha=alpha, + beta=beta, + unk_score_offset=unk_score_offset, + lm_score_boundary=lm_score_boundary, + ) + decoded_processor = decoded_processor_out.text + + logits_list = [array for array in logits] + decoder.reset_params( + alpha=alpha, + beta=beta, + unk_score_offset=unk_score_offset, + lm_score_boundary=lm_score_boundary, + ) + pool = get_context("fork").Pool() + decoded_decoder_out = decoder.decode_beams_batch( + pool, + logits_list, + ) + pool.close() + + decoded_decoder = [d[0][0] for d in decoded_decoder_out] + + self.assertListEqual(decoded_decoder, decoded_processor) + self.assertListEqual([" ", " "], decoded_processor) + lm_model = processor.decoder.model_container[processor.decoder._model_key] + self.assertEqual(lm_model.alpha, 2.0) + self.assertEqual(lm_model.beta, 5.0) + self.assertEqual(lm_model.unk_score_offset, -20.0) + self.assertEqual(lm_model.score_boundary, True) + def test_decoder_download_ignores_files(self): processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")