Add decoder_kwargs to send to LM on asr pipeline. (#15646)
Co-authored-by: Giuseppe Attanasio <giuseppeattanasio6@gmail.com> Co-authored-by: Giuseppe Attanasio <giuseppeattanasio6@gmail.com>
This commit is contained in:
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -180,7 +180,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
if "stride_length_s" in kwargs:
|
||||
preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
|
||||
|
||||
return preprocess_params, {}, {}
|
||||
postprocess_params = {}
|
||||
if "decoder_kwargs" in kwargs:
|
||||
postprocess_params["decoder_kwargs"] = kwargs["decoder_kwargs"]
|
||||
|
||||
return preprocess_params, {}, postprocess_params
|
||||
|
||||
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
|
||||
if isinstance(inputs, str):
|
||||
@@ -319,7 +323,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
extra = model_inputs
|
||||
return {"is_last": is_last, **out, **extra}
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None):
|
||||
if self.type == "ctc_with_lm":
|
||||
final_logits = []
|
||||
for outputs in model_outputs:
|
||||
@@ -334,9 +338,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
right_n = total_n - right
|
||||
logits = logits[:, left:right_n]
|
||||
final_logits.append(logits)
|
||||
if decoder_kwargs is None:
|
||||
decoder_kwargs = {}
|
||||
logits = np.concatenate(final_logits, axis=1)
|
||||
logits = logits.squeeze(0)
|
||||
text = self.decoder.decode_beams(logits)[0][0]
|
||||
text = self.decoder.decode_beams(logits, **decoder_kwargs)[0][0]
|
||||
else:
|
||||
skip_special_tokens = self.type != "ctc"
|
||||
tokens = np.concatenate([outputs["tokens"].numpy() for outputs in model_outputs], axis=-1)
|
||||
|
||||
@@ -365,10 +365,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
audio_tiled = np.tile(audio, n_repeats)
|
||||
|
||||
output = speech_recognizer([audio_tiled], batch_size=2)
|
||||
|
||||
self.assertEqual(output, [{"text": ANY(str)}])
|
||||
self.assertEqual(output[0]["text"][:6], "<s> <s")
|
||||
|
||||
# Making sure the argument are passed to the decoder
|
||||
# Since no change happens in the result, check the error comes from
|
||||
# the `decode_beams` function.
|
||||
with self.assertRaises(TypeError) as e:
|
||||
output = speech_recognizer([audio_tiled], decoder_kwargs={"num_beams": 2})
|
||||
self.assertContains(e.msg, "TypeError: decode_beams() got an unexpected keyword argument 'num_beams'")
|
||||
output = speech_recognizer([audio_tiled], decoder_kwargs={"beam_width": 2})
|
||||
|
||||
@require_torch
|
||||
@require_pyctcdecode
|
||||
def test_with_local_lm_fast(self):
|
||||
|
||||
Reference in New Issue
Block a user