[Wav2Vec2ProcessorWithLM] add alpha & beta to batch decode & decode (#15465)
This commit is contained in:
committed by
GitHub
parent
1d94d57546
commit
d718c0c3a8
@@ -253,6 +253,10 @@ class Wav2Vec2ProcessorWithLM:
|
|||||||
token_min_logp: Optional[float] = None,
|
token_min_logp: Optional[float] = None,
|
||||||
hotwords: Optional[Iterable[str]] = None,
|
hotwords: Optional[Iterable[str]] = None,
|
||||||
hotword_weight: Optional[float] = None,
|
hotword_weight: Optional[float] = None,
|
||||||
|
alpha: Optional[float] = None,
|
||||||
|
beta: Optional[float] = None,
|
||||||
|
unk_score_offset: Optional[float] = None,
|
||||||
|
lm_score_boundary: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Batch decode output logits to audio transcription with language model support.
|
Batch decode output logits to audio transcription with language model support.
|
||||||
@@ -280,6 +284,14 @@ class Wav2Vec2ProcessorWithLM:
|
|||||||
List of words with extra importance, can be OOV for LM
|
List of words with extra importance, can be OOV for LM
|
||||||
hotword_weight (`float`, *optional*):
|
hotword_weight (`float`, *optional*):
|
||||||
Weight factor for hotword importance. Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT.
|
Weight factor for hotword importance. Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT.
|
||||||
|
alpha (`float`, *optional*):
|
||||||
|
Weight for language model during shallow fusion
|
||||||
|
beta (`float`, *optional*):
|
||||||
|
Weight for length score adjustment during scoring
|
||||||
|
unk_score_offset (`float`, *optional*):
|
||||||
|
Amount of log score offset for unknown tokens
|
||||||
|
lm_score_boundary (`bool`, *optional*):
|
||||||
|
Whether to have kenlm respect boundaries when scoring
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
|
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
|
||||||
@@ -298,6 +310,11 @@ class Wav2Vec2ProcessorWithLM:
|
|||||||
token_min_logp = token_min_logp if token_min_logp is not None else DEFAULT_MIN_TOKEN_LOGP
|
token_min_logp = token_min_logp if token_min_logp is not None else DEFAULT_MIN_TOKEN_LOGP
|
||||||
hotword_weight = hotword_weight if hotword_weight is not None else DEFAULT_HOTWORD_WEIGHT
|
hotword_weight = hotword_weight if hotword_weight is not None else DEFAULT_HOTWORD_WEIGHT
|
||||||
|
|
||||||
|
# reset params at every forward call. It's just a `set` method in pyctcdecode
|
||||||
|
self.decoder.reset_params(
|
||||||
|
alpha=alpha, beta=beta, unk_score_offset=unk_score_offset, lm_score_boundary=lm_score_boundary
|
||||||
|
)
|
||||||
|
|
||||||
# create multiprocessing pool and list numpy arrays
|
# create multiprocessing pool and list numpy arrays
|
||||||
logits_list = [array for array in logits]
|
logits_list = [array for array in logits]
|
||||||
pool = get_context("fork").Pool(num_processes)
|
pool = get_context("fork").Pool(num_processes)
|
||||||
@@ -330,6 +347,10 @@ class Wav2Vec2ProcessorWithLM:
|
|||||||
token_min_logp: Optional[float] = None,
|
token_min_logp: Optional[float] = None,
|
||||||
hotwords: Optional[Iterable[str]] = None,
|
hotwords: Optional[Iterable[str]] = None,
|
||||||
hotword_weight: Optional[float] = None,
|
hotword_weight: Optional[float] = None,
|
||||||
|
alpha: Optional[float] = None,
|
||||||
|
beta: Optional[float] = None,
|
||||||
|
unk_score_offset: Optional[float] = None,
|
||||||
|
lm_score_boundary: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Decode output logits to audio transcription with language model support.
|
Decode output logits to audio transcription with language model support.
|
||||||
@@ -349,6 +370,14 @@ class Wav2Vec2ProcessorWithLM:
|
|||||||
List of words with extra importance which can be missing from the LM's vocabulary, e.g. ["huggingface"]
|
List of words with extra importance which can be missing from the LM's vocabulary, e.g. ["huggingface"]
|
||||||
hotword_weight (`float`, *optional*):
|
hotword_weight (`float`, *optional*):
|
||||||
Weight multiplier that boosts hotword scores. Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT.
|
Weight multiplier that boosts hotword scores. Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT.
|
||||||
|
alpha (`float`, *optional*):
|
||||||
|
Weight for language model during shallow fusion
|
||||||
|
beta (`float`, *optional*):
|
||||||
|
Weight for length score adjustment during scoring
|
||||||
|
unk_score_offset (`float`, *optional*):
|
||||||
|
Amount of log score offset for unknown tokens
|
||||||
|
lm_score_boundary (`bool`, *optional*):
|
||||||
|
Whether to have kenlm respect boundaries when scoring
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
|
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
|
||||||
@@ -367,6 +396,11 @@ class Wav2Vec2ProcessorWithLM:
|
|||||||
token_min_logp = token_min_logp if token_min_logp is not None else DEFAULT_MIN_TOKEN_LOGP
|
token_min_logp = token_min_logp if token_min_logp is not None else DEFAULT_MIN_TOKEN_LOGP
|
||||||
hotword_weight = hotword_weight if hotword_weight is not None else DEFAULT_HOTWORD_WEIGHT
|
hotword_weight = hotword_weight if hotword_weight is not None else DEFAULT_HOTWORD_WEIGHT
|
||||||
|
|
||||||
|
# reset params at every forward call. It's just a `set` method in pyctcdecode
|
||||||
|
self.decoder.reset_params(
|
||||||
|
alpha=alpha, beta=beta, unk_score_offset=unk_score_offset, lm_score_boundary=lm_score_boundary
|
||||||
|
)
|
||||||
|
|
||||||
# pyctcdecode
|
# pyctcdecode
|
||||||
decoded_beams = self.decoder.decode_beams(
|
decoded_beams = self.decoder.decode_beams(
|
||||||
logits,
|
logits,
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from multiprocessing import Pool
|
from multiprocessing import get_context
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -196,7 +196,9 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
|||||||
decoded_processor = processor.batch_decode(logits).text
|
decoded_processor = processor.batch_decode(logits).text
|
||||||
|
|
||||||
logits_list = [array for array in logits]
|
logits_list = [array for array in logits]
|
||||||
decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(Pool(), logits_list)]
|
pool = get_context("fork").Pool()
|
||||||
|
decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(pool, logits_list)]
|
||||||
|
pool.close()
|
||||||
|
|
||||||
self.assertListEqual(decoded_decoder, decoded_processor)
|
self.assertListEqual(decoded_decoder, decoded_processor)
|
||||||
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor)
|
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor)
|
||||||
@@ -223,19 +225,68 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
|||||||
decoded_processor = decoded_processor_out.text
|
decoded_processor = decoded_processor_out.text
|
||||||
|
|
||||||
logits_list = [array for array in logits]
|
logits_list = [array for array in logits]
|
||||||
|
pool = get_context("fork").Pool()
|
||||||
decoded_decoder_out = decoder.decode_beams_batch(
|
decoded_decoder_out = decoder.decode_beams_batch(
|
||||||
Pool(),
|
pool,
|
||||||
logits_list,
|
logits_list,
|
||||||
beam_width=beam_width,
|
beam_width=beam_width,
|
||||||
beam_prune_logp=beam_prune_logp,
|
beam_prune_logp=beam_prune_logp,
|
||||||
token_min_logp=token_min_logp,
|
token_min_logp=token_min_logp,
|
||||||
)
|
)
|
||||||
|
pool.close()
|
||||||
|
|
||||||
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
|
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
|
||||||
|
|
||||||
self.assertListEqual(decoded_decoder, decoded_processor)
|
self.assertListEqual(decoded_decoder, decoded_processor)
|
||||||
self.assertListEqual(["<s> </s> </s>", "<s> <s> </s>"], decoded_processor)
|
self.assertListEqual(["<s> </s> </s>", "<s> <s> </s>"], decoded_processor)
|
||||||
|
|
||||||
|
def test_decoder_with_params_of_lm(self):
|
||||||
|
feature_extractor = self.get_feature_extractor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
decoder = self.get_decoder()
|
||||||
|
|
||||||
|
processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder)
|
||||||
|
|
||||||
|
logits = self._get_dummy_logits()
|
||||||
|
|
||||||
|
alpha = 2.0
|
||||||
|
beta = 5.0
|
||||||
|
unk_score_offset = -20.0
|
||||||
|
lm_score_boundary = True
|
||||||
|
|
||||||
|
decoded_processor_out = processor.batch_decode(
|
||||||
|
logits,
|
||||||
|
alpha=alpha,
|
||||||
|
beta=beta,
|
||||||
|
unk_score_offset=unk_score_offset,
|
||||||
|
lm_score_boundary=lm_score_boundary,
|
||||||
|
)
|
||||||
|
decoded_processor = decoded_processor_out.text
|
||||||
|
|
||||||
|
logits_list = [array for array in logits]
|
||||||
|
decoder.reset_params(
|
||||||
|
alpha=alpha,
|
||||||
|
beta=beta,
|
||||||
|
unk_score_offset=unk_score_offset,
|
||||||
|
lm_score_boundary=lm_score_boundary,
|
||||||
|
)
|
||||||
|
pool = get_context("fork").Pool()
|
||||||
|
decoded_decoder_out = decoder.decode_beams_batch(
|
||||||
|
pool,
|
||||||
|
logits_list,
|
||||||
|
)
|
||||||
|
pool.close()
|
||||||
|
|
||||||
|
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
|
||||||
|
|
||||||
|
self.assertListEqual(decoded_decoder, decoded_processor)
|
||||||
|
self.assertListEqual(["<s> </s> <s> </s> </s>", "</s> </s> <s> </s> </s>"], decoded_processor)
|
||||||
|
lm_model = processor.decoder.model_container[processor.decoder._model_key]
|
||||||
|
self.assertEqual(lm_model.alpha, 2.0)
|
||||||
|
self.assertEqual(lm_model.beta, 5.0)
|
||||||
|
self.assertEqual(lm_model.unk_score_offset, -20.0)
|
||||||
|
self.assertEqual(lm_model.score_boundary, True)
|
||||||
|
|
||||||
def test_decoder_download_ignores_files(self):
|
def test_decoder_download_ignores_files(self):
|
||||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
|
processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user