add scores to Wav2Vec2WithLMOutput (#15413)
* add scores to Wav2Vec2WithLMOutput * style fixup
This commit is contained in:
@@ -42,9 +42,15 @@ class Wav2Vec2DecoderWithLMOutput(ModelOutput):
|
||||
Args:
|
||||
text (list of `str`):
|
||||
Decoded logits in text from. Usually the speech transcription.
|
||||
logit_score (list of `float`):
|
||||
Total logit score of the beam associated with produced text.
|
||||
lm_score (list of `float`):
|
||||
Fused lm_score of the beam associated with produced text.
|
||||
"""
|
||||
|
||||
text: Union[List[str], str]
|
||||
logit_score: Union[List[float], float] = None
|
||||
lm_score: Union[List[float], float] = None
|
||||
|
||||
|
||||
class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
||||
@@ -283,7 +289,8 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
||||
)
|
||||
|
||||
# create multiprocessing pool and list numpy arrays
|
||||
logits_list = [array for array in logits]
|
||||
# filter out logits padding
|
||||
logits_list = [array[(array != -100.0).all(axis=-1)] for array in logits]
|
||||
pool = get_context("fork").Pool(num_processes)
|
||||
|
||||
# pyctcdecode
|
||||
@@ -300,11 +307,14 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
||||
# clone multi-processing pool
|
||||
pool.close()
|
||||
|
||||
# extract text
|
||||
batch_texts = [d[0][0] for d in decoded_beams]
|
||||
|
||||
# extract text and scores
|
||||
batch_texts, logit_scores, lm_scores = [], [], []
|
||||
for d in decoded_beams:
|
||||
batch_texts.append(d[0][0])
|
||||
logit_scores.append(d[0][-2])
|
||||
lm_scores.append(d[0][-1])
|
||||
# more output features will be added in the future
|
||||
return Wav2Vec2DecoderWithLMOutput(text=batch_texts)
|
||||
return Wav2Vec2DecoderWithLMOutput(text=batch_texts, logit_score=logit_scores, lm_score=lm_scores)
|
||||
|
||||
def decode(
|
||||
self,
|
||||
@@ -379,7 +389,9 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
||||
)
|
||||
|
||||
# more output features will be added in the future
|
||||
return Wav2Vec2DecoderWithLMOutput(text=decoded_beams[0][0])
|
||||
return Wav2Vec2DecoderWithLMOutput(
|
||||
text=decoded_beams[0][0], logit_score=decoded_beams[0][-2], lm_score=decoded_beams[0][-1]
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def as_target_processor(self):
|
||||
|
||||
@@ -178,12 +178,14 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
||||
|
||||
logits = self._get_dummy_logits(shape=(10, 16), seed=13)
|
||||
|
||||
decoded_processor = processor.decode(logits).text
|
||||
decoded_processor = processor.decode(logits)
|
||||
|
||||
decoded_decoder = decoder.decode_beams(logits)[0][0]
|
||||
decoded_decoder = decoder.decode_beams(logits)[0]
|
||||
|
||||
self.assertEqual(decoded_decoder, decoded_processor)
|
||||
self.assertEqual("</s> <s> </s>", decoded_processor)
|
||||
self.assertEqual(decoded_decoder[0], decoded_processor.text)
|
||||
self.assertEqual("</s> <s> </s>", decoded_processor.text)
|
||||
self.assertEqual(decoded_decoder[-2], decoded_processor.logit_score)
|
||||
self.assertEqual(decoded_decoder[-1], decoded_processor.lm_score)
|
||||
|
||||
def test_decoder_batch(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
@@ -194,15 +196,22 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
||||
|
||||
logits = self._get_dummy_logits()
|
||||
|
||||
decoded_processor = processor.batch_decode(logits).text
|
||||
decoded_processor = processor.batch_decode(logits)
|
||||
|
||||
logits_list = [array for array in logits]
|
||||
pool = get_context("fork").Pool()
|
||||
decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(pool, logits_list)]
|
||||
decoded_beams = decoder.decode_beams_batch(pool, logits_list)
|
||||
texts_decoder, logit_scores_decoder, lm_scores_decoder = [], [], []
|
||||
for beams in decoded_beams:
|
||||
texts_decoder.append(beams[0][0])
|
||||
logit_scores_decoder.append(beams[0][-2])
|
||||
lm_scores_decoder.append(beams[0][-1])
|
||||
pool.close()
|
||||
|
||||
self.assertListEqual(decoded_decoder, decoded_processor)
|
||||
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor)
|
||||
self.assertListEqual(texts_decoder, decoded_processor.text)
|
||||
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor.text)
|
||||
self.assertListEqual(logit_scores_decoder, decoded_processor.logit_score)
|
||||
self.assertListEqual(lm_scores_decoder, decoded_processor.lm_score)
|
||||
|
||||
def test_decoder_with_params(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
|
||||
Reference in New Issue
Block a user