add scores to Wav2Vec2WithLMOutput (#15413)
* add scores to Wav2Vec2WithLMOutput * style fixup
This commit is contained in:
@@ -178,12 +178,14 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
||||
|
||||
logits = self._get_dummy_logits(shape=(10, 16), seed=13)
|
||||
|
||||
decoded_processor = processor.decode(logits).text
|
||||
decoded_processor = processor.decode(logits)
|
||||
|
||||
decoded_decoder = decoder.decode_beams(logits)[0][0]
|
||||
decoded_decoder = decoder.decode_beams(logits)[0]
|
||||
|
||||
self.assertEqual(decoded_decoder, decoded_processor)
|
||||
self.assertEqual("</s> <s> </s>", decoded_processor)
|
||||
self.assertEqual(decoded_decoder[0], decoded_processor.text)
|
||||
self.assertEqual("</s> <s> </s>", decoded_processor.text)
|
||||
self.assertEqual(decoded_decoder[-2], decoded_processor.logit_score)
|
||||
self.assertEqual(decoded_decoder[-1], decoded_processor.lm_score)
|
||||
|
||||
def test_decoder_batch(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
@@ -194,15 +196,22 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
||||
|
||||
logits = self._get_dummy_logits()
|
||||
|
||||
decoded_processor = processor.batch_decode(logits).text
|
||||
decoded_processor = processor.batch_decode(logits)
|
||||
|
||||
logits_list = [array for array in logits]
|
||||
pool = get_context("fork").Pool()
|
||||
decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(pool, logits_list)]
|
||||
decoded_beams = decoder.decode_beams_batch(pool, logits_list)
|
||||
texts_decoder, logit_scores_decoder, lm_scores_decoder = [], [], []
|
||||
for beams in decoded_beams:
|
||||
texts_decoder.append(beams[0][0])
|
||||
logit_scores_decoder.append(beams[0][-2])
|
||||
lm_scores_decoder.append(beams[0][-1])
|
||||
pool.close()
|
||||
|
||||
self.assertListEqual(decoded_decoder, decoded_processor)
|
||||
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor)
|
||||
self.assertListEqual(texts_decoder, decoded_processor.text)
|
||||
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor.text)
|
||||
self.assertListEqual(logit_scores_decoder, decoded_processor.logit_score)
|
||||
self.assertListEqual(lm_scores_decoder, decoded_processor.lm_score)
|
||||
|
||||
def test_decoder_with_params(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
|
||||
Reference in New Issue
Block a user