Add decoder_kwargs to send to LM on asr pipeline. (#15646)

Co-authored-by: Giuseppe Attanasio <giuseppeattanasio6@gmail.com>

Co-authored-by: Giuseppe Attanasio <giuseppeattanasio6@gmail.com>
This commit is contained in:
Nicolas Patry
2022-02-15 17:53:24 +01:00
committed by GitHub
parent cdf19c501d
commit a3dbbc3467
2 changed files with 18 additions and 5 deletions

View File

@@ -365,10 +365,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
audio_tiled = np.tile(audio, n_repeats)
output = speech_recognizer([audio_tiled], batch_size=2)
self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "<s> <s")
# Making sure the arguments are passed to the decoder
# Since no change happens in the result, check the error comes from
# the `decode_beams` function.
with self.assertRaises(TypeError) as e:
output = speech_recognizer([audio_tiled], decoder_kwargs={"num_beams": 2})
self.assertContains(e.msg, "TypeError: decode_beams() got an unexpected keyword argument 'num_beams'")
output = speech_recognizer([audio_tiled], decoder_kwargs={"beam_width": 2})
@require_torch
@require_pyctcdecode
def test_with_local_lm_fast(self):