add scores to Wav2Vec2WithLMOutput (#15413)

* add scores to Wav2Vec2WithLMOutput

* style fixup
This commit is contained in:
arampacha
2022-02-15 17:40:50 +02:00
committed by GitHub
parent 45f56580a7
commit 67047b86ce
2 changed files with 35 additions and 14 deletions

View File

@@ -178,12 +178,14 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
logits = self._get_dummy_logits(shape=(10, 16), seed=13)
decoded_processor = processor.decode(logits).text
decoded_processor = processor.decode(logits)
decoded_decoder = decoder.decode_beams(logits)[0][0]
decoded_decoder = decoder.decode_beams(logits)[0]
self.assertEqual(decoded_decoder, decoded_processor)
self.assertEqual("</s> <s> </s>", decoded_processor)
self.assertEqual(decoded_decoder[0], decoded_processor.text)
self.assertEqual("</s> <s> </s>", decoded_processor.text)
self.assertEqual(decoded_decoder[-2], decoded_processor.logit_score)
self.assertEqual(decoded_decoder[-1], decoded_processor.lm_score)
def test_decoder_batch(self):
feature_extractor = self.get_feature_extractor()
@@ -194,15 +196,22 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
logits = self._get_dummy_logits()
decoded_processor = processor.batch_decode(logits).text
decoded_processor = processor.batch_decode(logits)
logits_list = [array for array in logits]
pool = get_context("fork").Pool()
decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(pool, logits_list)]
decoded_beams = decoder.decode_beams_batch(pool, logits_list)
texts_decoder, logit_scores_decoder, lm_scores_decoder = [], [], []
for beams in decoded_beams:
texts_decoder.append(beams[0][0])
logit_scores_decoder.append(beams[0][-2])
lm_scores_decoder.append(beams[0][-1])
pool.close()
self.assertListEqual(decoded_decoder, decoded_processor)
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor)
self.assertListEqual(texts_decoder, decoded_processor.text)
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor.text)
self.assertListEqual(logit_scores_decoder, decoded_processor.logit_score)
self.assertListEqual(lm_scores_decoder, decoded_processor.lm_score)
def test_decoder_with_params(self):
feature_extractor = self.get_feature_extractor()