From a3dbbc346763c8eaa49577a448e5b5a2da1428ed Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 15 Feb 2022 17:53:24 +0100 Subject: [PATCH] Add `decoder_kwargs` to send to LM on asr pipeline. (#15646) Co-authored-by: Giuseppe Attanasio Co-authored-by: Giuseppe Attanasio --- .../pipelines/automatic_speech_recognition.py | 14 ++++++++++---- .../test_pipelines_automatic_speech_recognition.py | 9 ++++++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index e57fb7d5e4..df0c24a5a5 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import defaultdict -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Dict, Optional, Union import numpy as np @@ -180,7 +180,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): if "stride_length_s" in kwargs: preprocess_params["stride_length_s"] = kwargs["stride_length_s"] - return preprocess_params, {}, {} + postprocess_params = {} + if "decoder_kwargs" in kwargs: + postprocess_params["decoder_kwargs"] = kwargs["decoder_kwargs"] + + return preprocess_params, {}, postprocess_params def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None): if isinstance(inputs, str): @@ -319,7 +323,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): extra = model_inputs return {"is_last": is_last, **out, **extra} - def postprocess(self, model_outputs): + def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None): if self.type == "ctc_with_lm": final_logits = [] for outputs in model_outputs: @@ -334,9 +338,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): right_n = total_n - right logits = logits[:, left:right_n] final_logits.append(logits) + if decoder_kwargs is None: + decoder_kwargs = {} logits = np.concatenate(final_logits, axis=1) logits = logits.squeeze(0) - text = self.decoder.decode_beams(logits)[0][0] + text = self.decoder.decode_beams(logits, **decoder_kwargs)[0][0] else: skip_special_tokens = self.type != "ctc" tokens = np.concatenate([outputs["tokens"].numpy() for outputs in model_outputs], axis=-1) diff --git a/tests/test_pipelines_automatic_speech_recognition.py b/tests/test_pipelines_automatic_speech_recognition.py index 37ab808e77..5e1adbc27d 100644 --- a/tests/test_pipelines_automatic_speech_recognition.py +++ b/tests/test_pipelines_automatic_speech_recognition.py @@ -365,10 +365,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel audio_tiled = np.tile(audio, n_repeats) output = speech_recognizer([audio_tiled], batch_size=2) - self.assertEqual(output, [{"text": ANY(str)}]) self.assertEqual(output[0]["text"][:6], "