From 67047b86ce7143888d2f29d9d067f7527e09ab70 Mon Sep 17 00:00:00 2001
From: arampacha <69807323+arampacha@users.noreply.github.com>
Date: Tue, 15 Feb 2022 17:40:50 +0200
Subject: [PATCH] add scores to Wav2Vec2WithLMOutput (#15413)
* add scores to Wav2Vec2WithLMOutput
* style fixup
---
.../processing_wav2vec2_with_lm.py | 24 +++++++++++++-----
tests/test_processor_wav2vec2_with_lm.py | 25 +++++++++++++------
2 files changed, 35 insertions(+), 14 deletions(-)
diff --git a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
index 4947ce39ae..1c0fb5d0bf 100644
--- a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
+++ b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
@@ -42,9 +42,15 @@ class Wav2Vec2DecoderWithLMOutput(ModelOutput):
Args:
text (list of `str`):
Decoded logits in text from. Usually the speech transcription.
+ logit_score (list of `float`):
+ Total logit score of the beam associated with produced text.
+ lm_score (list of `float`):
+ Fused lm_score of the beam associated with produced text.
"""
text: Union[List[str], str]
+ logit_score: Union[List[float], float] = None
+ lm_score: Union[List[float], float] = None
class Wav2Vec2ProcessorWithLM(ProcessorMixin):
@@ -283,7 +289,8 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
)
# create multiprocessing pool and list numpy arrays
- logits_list = [array for array in logits]
+ # filter out logits padding
+ logits_list = [array[(array != -100.0).all(axis=-1)] for array in logits]
pool = get_context("fork").Pool(num_processes)
# pyctcdecode
@@ -300,11 +307,14 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
# clone multi-processing pool
pool.close()
- # extract text
- batch_texts = [d[0][0] for d in decoded_beams]
-
+ # extract text and scores
+ batch_texts, logit_scores, lm_scores = [], [], []
+ for d in decoded_beams:
+ batch_texts.append(d[0][0])
+ logit_scores.append(d[0][-2])
+ lm_scores.append(d[0][-1])
# more output features will be added in the future
- return Wav2Vec2DecoderWithLMOutput(text=batch_texts)
+ return Wav2Vec2DecoderWithLMOutput(text=batch_texts, logit_score=logit_scores, lm_score=lm_scores)
def decode(
self,
@@ -379,7 +389,9 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
)
# more output features will be added in the future
- return Wav2Vec2DecoderWithLMOutput(text=decoded_beams[0][0])
+ return Wav2Vec2DecoderWithLMOutput(
+ text=decoded_beams[0][0], logit_score=decoded_beams[0][-2], lm_score=decoded_beams[0][-1]
+ )
@contextmanager
def as_target_processor(self):
diff --git a/tests/test_processor_wav2vec2_with_lm.py b/tests/test_processor_wav2vec2_with_lm.py
index 4d562007b8..b3f4cb3cc7 100644
--- a/tests/test_processor_wav2vec2_with_lm.py
+++ b/tests/test_processor_wav2vec2_with_lm.py
@@ -178,12 +178,14 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
logits = self._get_dummy_logits(shape=(10, 16), seed=13)
- decoded_processor = processor.decode(logits).text
+ decoded_processor = processor.decode(logits)
- decoded_decoder = decoder.decode_beams(logits)[0][0]
+ decoded_decoder = decoder.decode_beams(logits)[0]
- self.assertEqual(decoded_decoder, decoded_processor)
- self.assertEqual(" ", decoded_processor)
+ self.assertEqual(decoded_decoder[0], decoded_processor.text)
+ self.assertEqual(" ", decoded_processor.text)
+ self.assertEqual(decoded_decoder[-2], decoded_processor.logit_score)
+ self.assertEqual(decoded_decoder[-1], decoded_processor.lm_score)
def test_decoder_batch(self):
feature_extractor = self.get_feature_extractor()
@@ -194,15 +196,22 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
logits = self._get_dummy_logits()
- decoded_processor = processor.batch_decode(logits).text
+ decoded_processor = processor.batch_decode(logits)
logits_list = [array for array in logits]
pool = get_context("fork").Pool()
- decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(pool, logits_list)]
+ decoded_beams = decoder.decode_beams_batch(pool, logits_list)
+ texts_decoder, logit_scores_decoder, lm_scores_decoder = [], [], []
+ for beams in decoded_beams:
+ texts_decoder.append(beams[0][0])
+ logit_scores_decoder.append(beams[0][-2])
+ lm_scores_decoder.append(beams[0][-1])
pool.close()
- self.assertListEqual(decoded_decoder, decoded_processor)
- self.assertListEqual([" ", " "], decoded_processor)
+ self.assertListEqual(texts_decoder, decoded_processor.text)
+ self.assertListEqual([" ", " "], decoded_processor.text)
+ self.assertListEqual(logit_scores_decoder, decoded_processor.logit_score)
+ self.assertListEqual(lm_scores_decoder, decoded_processor.lm_score)
def test_decoder_with_params(self):
feature_extractor = self.get_feature_extractor()