add scores to Wav2Vec2WithLMOutput (#15413)
* add scores to Wav2Vec2WithLMOutput * style fixup
This commit is contained in:
@@ -42,9 +42,15 @@ class Wav2Vec2DecoderWithLMOutput(ModelOutput):
|
|||||||
Args:
|
Args:
|
||||||
text (list of `str`):
|
text (list of `str`):
|
||||||
Decoded logits in text from. Usually the speech transcription.
|
Decoded logits in text from. Usually the speech transcription.
|
||||||
|
logit_score (list of `float`):
|
||||||
|
Total logit score of the beam associated with produced text.
|
||||||
|
lm_score (list of `float`):
|
||||||
|
Fused lm_score of the beam associated with produced text.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
text: Union[List[str], str]
|
text: Union[List[str], str]
|
||||||
|
logit_score: Union[List[float], float] = None
|
||||||
|
lm_score: Union[List[float], float] = None
|
||||||
|
|
||||||
|
|
||||||
class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
||||||
@@ -283,7 +289,8 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# create multiprocessing pool and list numpy arrays
|
# create multiprocessing pool and list numpy arrays
|
||||||
logits_list = [array for array in logits]
|
# filter out logits padding
|
||||||
|
logits_list = [array[(array != -100.0).all(axis=-1)] for array in logits]
|
||||||
pool = get_context("fork").Pool(num_processes)
|
pool = get_context("fork").Pool(num_processes)
|
||||||
|
|
||||||
# pyctcdecode
|
# pyctcdecode
|
||||||
@@ -300,11 +307,14 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
# clone multi-processing pool
|
# clone multi-processing pool
|
||||||
pool.close()
|
pool.close()
|
||||||
|
|
||||||
# extract text
|
# extract text and scores
|
||||||
batch_texts = [d[0][0] for d in decoded_beams]
|
batch_texts, logit_scores, lm_scores = [], [], []
|
||||||
|
for d in decoded_beams:
|
||||||
|
batch_texts.append(d[0][0])
|
||||||
|
logit_scores.append(d[0][-2])
|
||||||
|
lm_scores.append(d[0][-1])
|
||||||
# more output features will be added in the future
|
# more output features will be added in the future
|
||||||
return Wav2Vec2DecoderWithLMOutput(text=batch_texts)
|
return Wav2Vec2DecoderWithLMOutput(text=batch_texts, logit_score=logit_scores, lm_score=lm_scores)
|
||||||
|
|
||||||
def decode(
|
def decode(
|
||||||
self,
|
self,
|
||||||
@@ -379,7 +389,9 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# more output features will be added in the future
|
# more output features will be added in the future
|
||||||
return Wav2Vec2DecoderWithLMOutput(text=decoded_beams[0][0])
|
return Wav2Vec2DecoderWithLMOutput(
|
||||||
|
text=decoded_beams[0][0], logit_score=decoded_beams[0][-2], lm_score=decoded_beams[0][-1]
|
||||||
|
)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def as_target_processor(self):
|
def as_target_processor(self):
|
||||||
|
|||||||
@@ -178,12 +178,14 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
|||||||
|
|
||||||
logits = self._get_dummy_logits(shape=(10, 16), seed=13)
|
logits = self._get_dummy_logits(shape=(10, 16), seed=13)
|
||||||
|
|
||||||
decoded_processor = processor.decode(logits).text
|
decoded_processor = processor.decode(logits)
|
||||||
|
|
||||||
decoded_decoder = decoder.decode_beams(logits)[0][0]
|
decoded_decoder = decoder.decode_beams(logits)[0]
|
||||||
|
|
||||||
self.assertEqual(decoded_decoder, decoded_processor)
|
self.assertEqual(decoded_decoder[0], decoded_processor.text)
|
||||||
self.assertEqual("</s> <s> </s>", decoded_processor)
|
self.assertEqual("</s> <s> </s>", decoded_processor.text)
|
||||||
|
self.assertEqual(decoded_decoder[-2], decoded_processor.logit_score)
|
||||||
|
self.assertEqual(decoded_decoder[-1], decoded_processor.lm_score)
|
||||||
|
|
||||||
def test_decoder_batch(self):
|
def test_decoder_batch(self):
|
||||||
feature_extractor = self.get_feature_extractor()
|
feature_extractor = self.get_feature_extractor()
|
||||||
@@ -194,15 +196,22 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
|||||||
|
|
||||||
logits = self._get_dummy_logits()
|
logits = self._get_dummy_logits()
|
||||||
|
|
||||||
decoded_processor = processor.batch_decode(logits).text
|
decoded_processor = processor.batch_decode(logits)
|
||||||
|
|
||||||
logits_list = [array for array in logits]
|
logits_list = [array for array in logits]
|
||||||
pool = get_context("fork").Pool()
|
pool = get_context("fork").Pool()
|
||||||
decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(pool, logits_list)]
|
decoded_beams = decoder.decode_beams_batch(pool, logits_list)
|
||||||
|
texts_decoder, logit_scores_decoder, lm_scores_decoder = [], [], []
|
||||||
|
for beams in decoded_beams:
|
||||||
|
texts_decoder.append(beams[0][0])
|
||||||
|
logit_scores_decoder.append(beams[0][-2])
|
||||||
|
lm_scores_decoder.append(beams[0][-1])
|
||||||
pool.close()
|
pool.close()
|
||||||
|
|
||||||
self.assertListEqual(decoded_decoder, decoded_processor)
|
self.assertListEqual(texts_decoder, decoded_processor.text)
|
||||||
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor)
|
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor.text)
|
||||||
|
self.assertListEqual(logit_scores_decoder, decoded_processor.logit_score)
|
||||||
|
self.assertListEqual(lm_scores_decoder, decoded_processor.lm_score)
|
||||||
|
|
||||||
def test_decoder_with_params(self):
|
def test_decoder_with_params(self):
|
||||||
feature_extractor = self.get_feature_extractor()
|
feature_extractor = self.get_feature_extractor()
|
||||||
|
|||||||
Reference in New Issue
Block a user