From 14d058b9404b2a0659441038e120508e4a9ae10c Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Tue, 24 Jan 2023 19:27:56 +0100 Subject: [PATCH] [W2V2 with LM] Fix decoder test with params (#21277) --- .../test_processor_wav2vec2_with_lm.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py index 7e0892486e..829c609681 100644 --- a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py +++ b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py @@ -230,7 +230,6 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): self.assertListEqual(logit_scores_decoder, decoded_processor.logit_score) self.assertListEqual(lm_scores_decoder, decoded_processor.lm_score) - @unittest.skip("Fix me Sanchit") def test_decoder_with_params(self): feature_extractor = self.get_feature_extractor() tokenizer = self.get_tokenizer() @@ -240,7 +239,7 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): logits = self._get_dummy_logits() - beam_width = 20 + beam_width = 15 beam_prune_logp = -20.0 token_min_logp = -4.0 @@ -264,9 +263,17 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): ) decoded_decoder = [d[0][0] for d in decoded_decoder_out] + logit_scores = [d[0][2] for d in decoded_decoder_out] + lm_scores = [d[0][3] for d in decoded_decoder_out] self.assertListEqual(decoded_decoder, decoded_processor) - self.assertListEqual([" ", " "], decoded_processor) + self.assertListEqual([" ", " "], decoded_processor) + + self.assertTrue(np.array_equal(logit_scores, decoded_processor_out.logit_score)) + self.assertTrue(np.allclose([-20.054, -18.447], logit_scores, atol=1e-3)) + + self.assertTrue(np.array_equal(lm_scores, decoded_processor_out.lm_score)) + self.assertTrue(np.allclose([-15.554, -13.9474], lm_scores, atol=1e-3)) def test_decoder_with_params_of_lm(self): feature_extractor = self.get_feature_extractor()