From 14d058b9404b2a0659441038e120508e4a9ae10c Mon Sep 17 00:00:00 2001
From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Date: Tue, 24 Jan 2023 19:27:56 +0100
Subject: [PATCH] [W2V2 with LM] Fix decoder test with params (#21277)
---
.../test_processor_wav2vec2_with_lm.py | 13 ++++++++++---
1 file changed, 10 insertions(+), 3 deletions(-)
diff --git a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py
index 7e0892486e..829c609681 100644
--- a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py
+++ b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py
@@ -230,7 +230,6 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
self.assertListEqual(logit_scores_decoder, decoded_processor.logit_score)
self.assertListEqual(lm_scores_decoder, decoded_processor.lm_score)
- @unittest.skip("Fix me Sanchit")
def test_decoder_with_params(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
@@ -240,7 +239,7 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
logits = self._get_dummy_logits()
- beam_width = 20
+ beam_width = 15
beam_prune_logp = -20.0
token_min_logp = -4.0
@@ -264,9 +263,17 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
)
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
+ logit_scores = [d[0][2] for d in decoded_decoder_out]
+ lm_scores = [d[0][3] for d in decoded_decoder_out]
self.assertListEqual(decoded_decoder, decoded_processor)
- self.assertListEqual([" ", " "], decoded_processor)
+ self.assertListEqual([" ", " "], decoded_processor)
+
+ self.assertTrue(np.array_equal(logit_scores, decoded_processor_out.logit_score))
+ self.assertTrue(np.allclose([-20.054, -18.447], logit_scores, atol=1e-3))
+
+ self.assertTrue(np.array_equal(lm_scores, decoded_processor_out.lm_score))
+ self.assertTrue(np.allclose([-15.554, -13.9474], lm_scores, atol=1e-3))
def test_decoder_with_params_of_lm(self):
feature_extractor = self.get_feature_extractor()