[Wav2Vec2ProcessorWithLM] add alpha & beta to batch decode & decode (#15465)

This commit is contained in:
Patrick von Platen
2022-02-02 12:59:40 +01:00
committed by GitHub
parent 1d94d57546
commit d718c0c3a8
2 changed files with 88 additions and 3 deletions

View File

@@ -17,7 +17,7 @@ import os
import shutil
import tempfile
import unittest
from multiprocessing import Pool
from multiprocessing import get_context
from pathlib import Path
import numpy as np
@@ -196,7 +196,9 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
decoded_processor = processor.batch_decode(logits).text
logits_list = [array for array in logits]
decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(Pool(), logits_list)]
pool = get_context("fork").Pool()
decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(pool, logits_list)]
pool.close()
self.assertListEqual(decoded_decoder, decoded_processor)
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor)
@@ -223,19 +225,68 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
decoded_processor = decoded_processor_out.text
logits_list = [array for array in logits]
pool = get_context("fork").Pool()
decoded_decoder_out = decoder.decode_beams_batch(
Pool(),
pool,
logits_list,
beam_width=beam_width,
beam_prune_logp=beam_prune_logp,
token_min_logp=token_min_logp,
)
pool.close()
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
self.assertListEqual(decoded_decoder, decoded_processor)
self.assertListEqual(["<s> </s> </s>", "<s> <s> </s>"], decoded_processor)
def test_decoder_with_params_of_lm(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
decoder = self.get_decoder()
processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder)
logits = self._get_dummy_logits()
alpha = 2.0
beta = 5.0
unk_score_offset = -20.0
lm_score_boundary = True
decoded_processor_out = processor.batch_decode(
logits,
alpha=alpha,
beta=beta,
unk_score_offset=unk_score_offset,
lm_score_boundary=lm_score_boundary,
)
decoded_processor = decoded_processor_out.text
logits_list = [array for array in logits]
decoder.reset_params(
alpha=alpha,
beta=beta,
unk_score_offset=unk_score_offset,
lm_score_boundary=lm_score_boundary,
)
pool = get_context("fork").Pool()
decoded_decoder_out = decoder.decode_beams_batch(
pool,
logits_list,
)
pool.close()
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
self.assertListEqual(decoded_decoder, decoded_processor)
self.assertListEqual(["<s> </s> <s> </s> </s>", "</s> </s> <s> </s> </s>"], decoded_processor)
lm_model = processor.decoder.model_container[processor.decoder._model_key]
self.assertEqual(lm_model.alpha, 2.0)
self.assertEqual(lm_model.beta, 5.0)
self.assertEqual(lm_model.unk_score_offset, -20.0)
self.assertEqual(lm_model.score_boundary, True)
def test_decoder_download_ignores_files(self):
processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")